Support server subprotocols

This commit is contained in:
Michał Leśniak 2024-07-05 18:49:31 +02:00
parent 9884c325dd
commit a2cbfcebec
7 changed files with 62 additions and 2 deletions

View File

@ -265,7 +265,8 @@ namespace ix
}
WebSocketInitResult status =
_ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request);
_ws.connectToSocket(
std::move(socket), timeoutSecs, enablePerMessageDeflate, getSubProtocols(), request);
if (!status.success)
{
return status;

View File

@ -246,8 +246,30 @@ namespace ix
return WebSocketInitResult(true, status, "", headers, path);
}
void WebSocketHandshake::getSelectedSubProtocol(const std::vector<std::string>& subProtocols,
std::string& selectedSubProtocol,
const std::string& headerSubProtocol)
{
std::stringstream ss;
ss << headerSubProtocol;
std::string protocol;
while (std::getline(ss, protocol, ','))
{
bool subProtocolFound = false;
for (const auto& supportedSubProtocol : subProtocols)
{
if (protocol != supportedSubProtocol) continue;
selectedSubProtocol = protocol;
subProtocolFound = true;
break;
}
if (subProtocolFound) break;
}
}
WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request)
{
_requestInitCancellation = false;
@ -362,6 +384,17 @@ namespace ix
ss << "Connection: Upgrade\r\n";
ss << "Server: " << userAgent() << "\r\n";
if(!subProtocols.empty())
{
std::string headerSubProtocol = headers["sec-websocket-protocol"];
std::string selectedSubProtocol;
getSelectedSubProtocol(subProtocols, selectedSubProtocol, headerSubProtocol);
if(!selectedSubProtocol.empty())
{
ss << "Sec-WebSocket-Protocol: " << selectedSubProtocol << "\r\n";
}
}
// Parse the client headers. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

View File

@ -39,6 +39,7 @@ namespace ix
WebSocketInitResult serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request = nullptr);
private:
@ -49,6 +50,10 @@ namespace ix
bool insensitiveStringCompare(const std::string& a, const std::string& b);
static void getSelectedSubProtocol(const std::vector<std::string>& subProtocols,
std::string& selectedSubProtocol,
const std::string& headerSubProtocol);
std::atomic<bool>& _requestInitCancellation;
std::unique_ptr<Socket>& _socket;
WebSocketPerMessageDeflatePtr& _perMessageDeflate;

View File

@ -79,6 +79,16 @@ namespace ix
_onClientMessageCallback = callback;
}
void WebSocketServer::addSubProtocol(const std::string& subProtocol)
{
_subProtocols.push_back(subProtocol);
}
const std::vector<std::string>& WebSocketServer::getSubProtocols()
{
return _subProtocols;
}
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{
@ -97,6 +107,11 @@ namespace ix
webSocket->setAutoThreadName(false);
webSocket->setPingInterval(_pingIntervalSeconds);
if(!_subProtocols.empty())
{
for(const auto& subProtocol : _subProtocols)
webSocket->addSubProtocol(subProtocol);
}
if (_onConnectionCallback)
{

View File

@ -45,6 +45,9 @@ namespace ix
void setOnConnectionCallback(const OnConnectionCallback& callback);
void setOnClientMessageCallback(const OnClientMessageCallback& callback);
void addSubProtocol(const std::string& subProtocol);
const std::vector<std::string>& getSubProtocols();
// Get all the connected clients
std::set<std::shared_ptr<WebSocket>> getClients();
@ -63,6 +66,7 @@ namespace ix
bool _enablePong;
bool _enablePerMessageDeflate;
int _pingIntervalSeconds;
std::vector<std::string> _subProtocols;
OnConnectionCallback _onConnectionCallback;
OnClientMessageCallback _onClientMessageCallback;

View File

@ -172,6 +172,7 @@ namespace ix
WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs,
bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request)
{
std::lock_guard<std::mutex> lock(_socketMutex);
@ -190,7 +191,7 @@ namespace ix
_enablePerMessageDeflate);
auto result =
webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request);
webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, subProtocols, request);
if (result.success)
{
setReadyState(ReadyState::OPEN);

View File

@ -88,6 +88,7 @@ namespace ix
WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs,
bool enablePerMessageDeflate,
const std::vector<std::string>& subProtocols,
HttpRequestPtr request = nullptr);
PollResult poll();