From a2cbfcebec76457c9bf5f9cbf2df92a0a13ec10e Mon Sep 17 00:00:00 2001 From: Kapitan Date: Fri, 5 Jul 2024 18:49:31 +0200 Subject: [PATCH] Support server subprotocols --- ixwebsocket/IXWebSocket.cpp | 3 ++- ixwebsocket/IXWebSocketHandshake.cpp | 33 ++++++++++++++++++++++++++++ ixwebsocket/IXWebSocketHandshake.h | 5 +++++ ixwebsocket/IXWebSocketServer.cpp | 15 +++++++++++++ ixwebsocket/IXWebSocketServer.h | 4 ++++ ixwebsocket/IXWebSocketTransport.cpp | 3 ++- ixwebsocket/IXWebSocketTransport.h | 1 + 7 files changed, 62 insertions(+), 2 deletions(-) diff --git a/ixwebsocket/IXWebSocket.cpp b/ixwebsocket/IXWebSocket.cpp index 8ba50eea..fe88a78f 100644 --- a/ixwebsocket/IXWebSocket.cpp +++ b/ixwebsocket/IXWebSocket.cpp @@ -265,7 +265,8 @@ namespace ix } WebSocketInitResult status = - _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request); + _ws.connectToSocket( + std::move(socket), timeoutSecs, enablePerMessageDeflate, getSubProtocols(), request); if (!status.success) { return status; diff --git a/ixwebsocket/IXWebSocketHandshake.cpp b/ixwebsocket/IXWebSocketHandshake.cpp index 206b0790..39a05a6d 100644 --- a/ixwebsocket/IXWebSocketHandshake.cpp +++ b/ixwebsocket/IXWebSocketHandshake.cpp @@ -246,8 +246,30 @@ namespace ix return WebSocketInitResult(true, status, "", headers, path); } + void WebSocketHandshake::getSelectedSubProtocol(const std::vector& subProtocols, + std::string& selectedSubProtocol, + const std::string& headerSubProtocol) + { + std::stringstream ss; + ss << headerSubProtocol; + std::string protocol; + while (std::getline(ss, protocol, ',')) + { + bool subProtocolFound = false; + for (const auto& supportedSubProtocol : subProtocols) + { + if (protocol != supportedSubProtocol) continue; + selectedSubProtocol = protocol; + subProtocolFound = true; + break; + } + if (subProtocolFound) break; + } + } + WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, bool enablePerMessageDeflate, + const std::vector& subProtocols, HttpRequestPtr request) { _requestInitCancellation = false; @@ -362,6 +384,17 @@ namespace ix ss << "Connection: Upgrade\r\n"; ss << "Server: " << userAgent() << "\r\n"; + if(!subProtocols.empty()) + { + std::string headerSubProtocol = headers["sec-websocket-protocol"]; + std::string selectedSubProtocol; + getSelectedSubProtocol(subProtocols, selectedSubProtocol, headerSubProtocol); + if(!selectedSubProtocol.empty()) + { + ss << "Sec-WebSocket-Protocol: " << selectedSubProtocol << "\r\n"; + } + } + // Parse the client headers. Does it support deflate ? std::string header = headers["sec-websocket-extensions"]; WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); diff --git a/ixwebsocket/IXWebSocketHandshake.h b/ixwebsocket/IXWebSocketHandshake.h index a6bf2a15..81569b1b 100644 --- a/ixwebsocket/IXWebSocketHandshake.h +++ b/ixwebsocket/IXWebSocketHandshake.h @@ -39,6 +39,7 @@ namespace ix WebSocketInitResult serverHandshake(int timeoutSecs, bool enablePerMessageDeflate, + const std::vector& subProtocols, HttpRequestPtr request = nullptr); private: @@ -49,6 +50,10 @@ namespace ix bool insensitiveStringCompare(const std::string& a, const std::string& b); + static void getSelectedSubProtocol(const std::vector& subProtocols, + std::string& selectedSubProtocol, + const std::string& headerSubProtocol); + std::atomic& _requestInitCancellation; std::unique_ptr& _socket; WebSocketPerMessageDeflatePtr& _perMessageDeflate; diff --git a/ixwebsocket/IXWebSocketServer.cpp b/ixwebsocket/IXWebSocketServer.cpp index cb6988a5..ff4f59f5 100644 --- a/ixwebsocket/IXWebSocketServer.cpp +++ b/ixwebsocket/IXWebSocketServer.cpp @@ -79,6 +79,16 @@ namespace ix _onClientMessageCallback = callback; } + void WebSocketServer::addSubProtocol(const std::string& subProtocol) + { + _subProtocols.push_back(subProtocol); + } + + const std::vector& WebSocketServer::getSubProtocols() + { + return _subProtocols; + } + void WebSocketServer::handleConnection(std::unique_ptr socket, std::shared_ptr connectionState) { @@ -97,6 +107,11 @@ namespace ix webSocket->setAutoThreadName(false); webSocket->setPingInterval(_pingIntervalSeconds); + if(!_subProtocols.empty()) + { + for(const auto& subProtocol : _subProtocols) + webSocket->addSubProtocol(subProtocol); + } if (_onConnectionCallback) { diff --git a/ixwebsocket/IXWebSocketServer.h b/ixwebsocket/IXWebSocketServer.h index 7636074e..877e6e5e 100644 --- a/ixwebsocket/IXWebSocketServer.h +++ b/ixwebsocket/IXWebSocketServer.h @@ -45,6 +45,9 @@ namespace ix void setOnConnectionCallback(const OnConnectionCallback& callback); void setOnClientMessageCallback(const OnClientMessageCallback& callback); + void addSubProtocol(const std::string& subProtocol); + const std::vector& getSubProtocols(); + // Get all the connected clients std::set> getClients(); @@ -63,6 +66,7 @@ namespace ix bool _enablePong; bool _enablePerMessageDeflate; int _pingIntervalSeconds; + std::vector _subProtocols; OnConnectionCallback _onConnectionCallback; OnClientMessageCallback _onClientMessageCallback; diff --git a/ixwebsocket/IXWebSocketTransport.cpp b/ixwebsocket/IXWebSocketTransport.cpp index f9d36c52..47cd2ee3 100644 --- a/ixwebsocket/IXWebSocketTransport.cpp +++ b/ixwebsocket/IXWebSocketTransport.cpp @@ -172,6 +172,7 @@ namespace ix WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr socket, int timeoutSecs, bool enablePerMessageDeflate, + const std::vector& subProtocols, HttpRequestPtr request) { std::lock_guard lock(_socketMutex); @@ -190,7 +191,7 @@ namespace ix _enablePerMessageDeflate); auto result = - webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request); + webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, subProtocols, request); if (result.success) { setReadyState(ReadyState::OPEN); diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index 8473c55c..075c789b 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -88,6 +88,7 @@ namespace ix WebSocketInitResult connectToSocket(std::unique_ptr socket, int timeoutSecs, bool enablePerMessageDeflate, + const std::vector& subProtocols, HttpRequestPtr request = nullptr); PollResult poll();