Support server subprotocols

This commit is contained in:
Michał Leśniak 2024-07-05 18:49:31 +02:00
parent 9884c325dd
commit a2cbfcebec
7 changed files with 62 additions and 2 deletions

View File

@ -265,7 +265,8 @@ namespace ix
} }
WebSocketInitResult status = WebSocketInitResult status =
_ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request); _ws.connectToSocket(
std::move(socket), timeoutSecs, enablePerMessageDeflate, getSubProtocols(), request);
if (!status.success) if (!status.success)
{ {
return status; return status;

View File

@ -246,8 +246,30 @@ namespace ix
return WebSocketInitResult(true, status, "", headers, path); return WebSocketInitResult(true, status, "", headers, path);
} }
void WebSocketHandshake::getSelectedSubProtocol(const std::vector<std::string>& subProtocols,
std::string& selectedSubProtocol,
const std::string& headerSubProtocol)
{
std::stringstream ss;
ss << headerSubProtocol;
std::string protocol;
while (std::getline(ss, protocol, ','))
{
bool subProtocolFound = false;
for (const auto& supportedSubProtocol : subProtocols)
{
if (protocol != supportedSubProtocol) continue;
selectedSubProtocol = protocol;
subProtocolFound = true;
break;
}
if (subProtocolFound) break;
}
}
WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate, bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request) HttpRequestPtr request)
{ {
_requestInitCancellation = false; _requestInitCancellation = false;
@ -362,6 +384,17 @@ namespace ix
ss << "Connection: Upgrade\r\n"; ss << "Connection: Upgrade\r\n";
ss << "Server: " << userAgent() << "\r\n"; ss << "Server: " << userAgent() << "\r\n";
if(!subProtocols.empty())
{
std::string headerSubProtocol = headers["sec-websocket-protocol"];
std::string selectedSubProtocol;
getSelectedSubProtocol(subProtocols, selectedSubProtocol, headerSubProtocol);
if(!selectedSubProtocol.empty())
{
ss << "Sec-WebSocket-Protocol: " << selectedSubProtocol << "\r\n";
}
}
// Parse the client headers. Does it support deflate ? // Parse the client headers. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"]; std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

View File

@ -39,6 +39,7 @@ namespace ix
WebSocketInitResult serverHandshake(int timeoutSecs, WebSocketInitResult serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate, bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request = nullptr); HttpRequestPtr request = nullptr);
private: private:
@ -49,6 +50,10 @@ namespace ix
bool insensitiveStringCompare(const std::string& a, const std::string& b); bool insensitiveStringCompare(const std::string& a, const std::string& b);
static void getSelectedSubProtocol(const std::vector<std::string>& subProtocols,
std::string& selectedSubProtocol,
const std::string& headerSubProtocol);
std::atomic<bool>& _requestInitCancellation; std::atomic<bool>& _requestInitCancellation;
std::unique_ptr<Socket>& _socket; std::unique_ptr<Socket>& _socket;
WebSocketPerMessageDeflatePtr& _perMessageDeflate; WebSocketPerMessageDeflatePtr& _perMessageDeflate;

View File

@ -79,6 +79,16 @@ namespace ix
_onClientMessageCallback = callback; _onClientMessageCallback = callback;
} }
void WebSocketServer::addSubProtocol(const std::string& subProtocol)
{
_subProtocols.push_back(subProtocol);
}
const std::vector<std::string>& WebSocketServer::getSubProtocols()
{
return _subProtocols;
}
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket, void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState) std::shared_ptr<ConnectionState> connectionState)
{ {
@ -97,6 +107,11 @@ namespace ix
webSocket->setAutoThreadName(false); webSocket->setAutoThreadName(false);
webSocket->setPingInterval(_pingIntervalSeconds); webSocket->setPingInterval(_pingIntervalSeconds);
if(!_subProtocols.empty())
{
for(const auto& subProtocol : _subProtocols)
webSocket->addSubProtocol(subProtocol);
}
if (_onConnectionCallback) if (_onConnectionCallback)
{ {

View File

@ -45,6 +45,9 @@ namespace ix
void setOnConnectionCallback(const OnConnectionCallback& callback); void setOnConnectionCallback(const OnConnectionCallback& callback);
void setOnClientMessageCallback(const OnClientMessageCallback& callback); void setOnClientMessageCallback(const OnClientMessageCallback& callback);
void addSubProtocol(const std::string& subProtocol);
const std::vector<std::string>& getSubProtocols();
// Get all the connected clients // Get all the connected clients
std::set<std::shared_ptr<WebSocket>> getClients(); std::set<std::shared_ptr<WebSocket>> getClients();
@ -63,6 +66,7 @@ namespace ix
bool _enablePong; bool _enablePong;
bool _enablePerMessageDeflate; bool _enablePerMessageDeflate;
int _pingIntervalSeconds; int _pingIntervalSeconds;
std::vector<std::string> _subProtocols;
OnConnectionCallback _onConnectionCallback; OnConnectionCallback _onConnectionCallback;
OnClientMessageCallback _onClientMessageCallback; OnClientMessageCallback _onClientMessageCallback;

View File

@ -172,6 +172,7 @@ namespace ix
WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket, WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate, bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request) HttpRequestPtr request)
{ {
std::lock_guard<std::mutex> lock(_socketMutex); std::lock_guard<std::mutex> lock(_socketMutex);
@ -190,7 +191,7 @@ namespace ix
_enablePerMessageDeflate); _enablePerMessageDeflate);
auto result = auto result =
webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request); webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, subProtocols, request);
if (result.success) if (result.success)
{ {
setReadyState(ReadyState::OPEN); setReadyState(ReadyState::OPEN);

View File

@ -88,6 +88,7 @@ namespace ix
WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket, WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate, bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request = nullptr); HttpRequestPtr request = nullptr);
PollResult poll(); PollResult poll();