diff --git a/ixwebsocket/IXHttpServer.cpp b/ixwebsocket/IXHttpServer.cpp index c015854c..406d7d25 100644 --- a/ixwebsocket/IXHttpServer.cpp +++ b/ixwebsocket/IXHttpServer.cpp @@ -74,28 +74,14 @@ namespace ix int backlog, size_t maxConnections, int addressFamily, - int timeoutSecs) - : SocketServer(port, host, backlog, maxConnections, addressFamily) - , _connectedClientsCount(0) + int timeoutSecs, + int handshakeTimeoutSecs) + : WebSocketServer(port, host, backlog, maxConnections, handshakeTimeoutSecs, addressFamily) , _timeoutSecs(timeoutSecs) { setDefaultConnectionCallback(); } - HttpServer::~HttpServer() - { - stop(); - } - - void HttpServer::stop() - { - stopAcceptingConnections(); - - // FIXME: cancelling / closing active clients ... - - SocketServer::stop(); - } - void HttpServer::setOnConnectionCallback(const OnConnectionCallback& callback) { _onConnectionCallback = callback; @@ -104,34 +90,35 @@ namespace ix void HttpServer::handleConnection(std::unique_ptr socket, std::shared_ptr connectionState) { - _connectedClientsCount++; - auto ret = Http::parseRequest(socket, _timeoutSecs); // FIXME: handle errors in parseRequest if (std::get<0>(ret)) { - auto response = _onConnectionCallback(std::get<2>(ret), connectionState); - if (!Http::sendResponse(response, socket)) + auto request = std::get<2>(ret); + std::shared_ptr response; + if (request->headers["Upgrade"] == "websocket") { - logError("Cannot send response"); + WebSocketServer::handleUpgrade(std::move(socket), connectionState, request); + } + else + { + auto response = _onConnectionCallback(request, connectionState); + if (!Http::sendResponse(response, socket)) + { + logError("Cannot send response"); + } } } connectionState->setTerminated(); - - _connectedClientsCount--; - } - - size_t HttpServer::getConnectedClientsCount() - { - return _connectedClientsCount; } void HttpServer::setDefaultConnectionCallback() { setOnConnectionCallback( [this](HttpRequestPtr request, - std::shared_ptr connectionState) -> HttpResponsePtr { + std::shared_ptr connectionState) -> HttpResponsePtr + { std::string uri(request->uri); if (uri.empty() || uri == "/") { @@ -189,9 +176,9 @@ namespace ix // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections // setOnConnectionCallback( - [this, - redirectUrl](HttpRequestPtr request, - std::shared_ptr connectionState) -> HttpResponsePtr { + [this, redirectUrl](HttpRequestPtr request, + std::shared_ptr connectionState) -> HttpResponsePtr + { WebSocketHttpHeaders headers; headers["Server"] = userAgent(); @@ -222,7 +209,8 @@ namespace ix { setOnConnectionCallback( [this](HttpRequestPtr request, - std::shared_ptr connectionState) -> HttpResponsePtr { + std::shared_ptr connectionState) -> HttpResponsePtr + { WebSocketHttpHeaders headers; headers["Server"] = userAgent(); diff --git a/ixwebsocket/IXHttpServer.h b/ixwebsocket/IXHttpServer.h index 7de67631..0d74d709 100644 --- a/ixwebsocket/IXHttpServer.h +++ b/ixwebsocket/IXHttpServer.h @@ -7,8 +7,8 @@ #pragma once #include "IXHttp.h" -#include "IXSocketServer.h" #include "IXWebSocket.h" +#include "IXWebSocketServer.h" #include #include #include @@ -19,7 +19,7 @@ namespace ix { - class HttpServer final : public SocketServer + class HttpServer final : public WebSocketServer { public: using OnConnectionCallback = @@ -30,9 +30,8 @@ namespace ix int backlog = SocketServer::kDefaultTcpBacklog, size_t maxConnections = SocketServer::kDefaultMaxConnections, int addressFamily = SocketServer::kDefaultAddressFamily, - int timeoutSecs = HttpServer::kDefaultTimeoutSecs); - virtual ~HttpServer(); - virtual void stop() final; + int timeoutSecs = HttpServer::kDefaultTimeoutSecs, + int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs); void setOnConnectionCallback(const OnConnectionCallback& callback); @@ -41,10 +40,10 @@ namespace ix void makeDebugServer(); int getTimeoutSecs(); + private: // Member variables OnConnectionCallback _onConnectionCallback; - std::atomic _connectedClientsCount; const static int kDefaultTimeoutSecs; int _timeoutSecs; @@ -52,7 +51,6 @@ namespace ix // Methods virtual void handleConnection(std::unique_ptr, std::shared_ptr connectionState) final; - virtual size_t getConnectedClientsCount() final; void setDefaultConnectionCallback(); }; diff --git a/ixwebsocket/IXWebSocket.cpp b/ixwebsocket/IXWebSocket.cpp index 7aab7902..5c3ec9ad 100644 --- a/ixwebsocket/IXWebSocket.cpp +++ b/ixwebsocket/IXWebSocket.cpp @@ -41,7 +41,8 @@ namespace ix , _pingIntervalSecs(kDefaultPingIntervalSecs) { _ws.setOnCloseCallback( - [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) { + [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) + { _onMessageCallback( ix::make_unique(WebSocketMessageType::Close, emptyMsg, @@ -240,7 +241,8 @@ namespace ix WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr socket, int timeoutSecs, - bool enablePerMessageDeflate) + bool enablePerMessageDeflate, + HttpRequestPtr request) { { std::lock_guard lock(_configMutex); @@ -249,7 +251,7 @@ namespace ix } WebSocketInitResult status = - _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate); + _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request); if (!status.success) { return status; @@ -384,8 +386,9 @@ namespace ix [this](const std::string& msg, size_t wireSize, bool decompressionError, - WebSocketTransport::MessageKind messageKind) { - WebSocketMessageType webSocketMessageType{WebSocketMessageType::Error}; + WebSocketTransport::MessageKind messageKind) + { + WebSocketMessageType webSocketMessageType {WebSocketMessageType::Error}; switch (messageKind) { case WebSocketTransport::MessageKind::MSG_TEXT: diff --git a/ixwebsocket/IXWebSocket.h b/ixwebsocket/IXWebSocket.h index 37df88ca..84ef3b85 100644 --- a/ixwebsocket/IXWebSocket.h +++ b/ixwebsocket/IXWebSocket.h @@ -16,8 +16,8 @@ #include "IXWebSocketHttpHeaders.h" #include "IXWebSocketMessage.h" #include "IXWebSocketPerMessageDeflateOptions.h" -#include "IXWebSocketSendInfo.h" #include "IXWebSocketSendData.h" +#include "IXWebSocketSendInfo.h" #include "IXWebSocketTransport.h" #include #include @@ -128,7 +128,8 @@ namespace ix // Server WebSocketInitResult connectToSocket(std::unique_ptr, int timeoutSecs, - bool enablePerMessageDeflate); + bool enablePerMessageDeflate, + HttpRequestPtr request = nullptr); WebSocketTransport _ws; diff --git a/ixwebsocket/IXWebSocketHandshake.cpp b/ixwebsocket/IXWebSocketHandshake.cpp index 53b3c806..a0216dd7 100644 --- a/ixwebsocket/IXWebSocketHandshake.cpp +++ b/ixwebsocket/IXWebSocketHandshake.cpp @@ -240,28 +240,42 @@ namespace ix } WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, - bool enablePerMessageDeflate) + bool enablePerMessageDeflate, + HttpRequestPtr request) { _requestInitCancellation = false; auto isCancellationRequested = makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); - // Read first line - auto lineResult = _socket->readLine(isCancellationRequested); - auto lineValid = lineResult.first; - auto line = lineResult.second; + std::string method; + std::string uri; + std::string httpVersion; - if (!lineValid) + if (request) { - return sendErrorResponse(400, "Error reading HTTP request line"); + method = request->method; + uri = request->uri; + httpVersion = request->version; } + else + { + // Read first line + auto lineResult = _socket->readLine(isCancellationRequested); + auto lineValid = lineResult.first; + auto line = lineResult.second; - // Validate request line (GET /foo HTTP/1.1\r\n) - auto requestLine = Http::parseRequestLine(line); - auto method = std::get<0>(requestLine); - auto uri = std::get<1>(requestLine); - auto httpVersion = std::get<2>(requestLine); + if (!lineValid) + { + return sendErrorResponse(400, "Error reading HTTP request line"); + } + + // Validate request line (GET /foo HTTP/1.1\r\n) + auto requestLine = Http::parseRequestLine(line); + method = std::get<0>(requestLine); + uri = std::get<1>(requestLine); + httpVersion = std::get<2>(requestLine); + } if (method != "GET") { @@ -274,14 +288,22 @@ namespace ix "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion); } - // Retrieve and validate HTTP headers - auto result = parseHttpHeaders(_socket, isCancellationRequested); - auto headersValid = result.first; - auto headers = result.second; - - if (!headersValid) + WebSocketHttpHeaders headers; + if (request) { - return sendErrorResponse(400, "Error parsing HTTP headers"); + headers = request->headers; + } + else + { + // Retrieve and validate HTTP headers + auto result = parseHttpHeaders(_socket, isCancellationRequested); + auto headersValid = result.first; + headers = result.second; + + if (!headersValid) + { + return sendErrorResponse(400, "Error parsing HTTP headers"); + } } if (headers.find("sec-websocket-key") == headers.end()) diff --git a/ixwebsocket/IXWebSocketHandshake.h b/ixwebsocket/IXWebSocketHandshake.h index 0a275e43..3c153791 100644 --- a/ixwebsocket/IXWebSocketHandshake.h +++ b/ixwebsocket/IXWebSocketHandshake.h @@ -7,6 +7,7 @@ #pragma once #include "IXCancellationRequest.h" +#include "IXHttp.h" #include "IXSocket.h" #include "IXWebSocketHttpHeaders.h" #include "IXWebSocketInitResult.h" @@ -35,7 +36,9 @@ namespace ix int port, int timeoutSecs); - WebSocketInitResult serverHandshake(int timeoutSecs, bool enablePerMessageDeflate); + WebSocketInitResult serverHandshake(int timeoutSecs, + bool enablePerMessageDeflate, + HttpRequestPtr request = nullptr); private: std::string genRandomString(const int len); diff --git a/ixwebsocket/IXWebSocketServer.cpp b/ixwebsocket/IXWebSocketServer.cpp index 90593d50..03b0ea50 100644 --- a/ixwebsocket/IXWebSocketServer.cpp +++ b/ixwebsocket/IXWebSocketServer.cpp @@ -78,6 +78,15 @@ namespace ix void WebSocketServer::handleConnection(std::unique_ptr socket, std::shared_ptr connectionState) + { + handleUpgrade(std::move(socket), connectionState); + + connectionState->setTerminated(); + } + + void WebSocketServer::handleUpgrade(std::unique_ptr socket, + std::shared_ptr connectionState, + HttpRequestPtr request) { setThreadName("Srv:ws:" + connectionState->getId()); @@ -89,7 +98,7 @@ namespace ix if (!webSocket->isOnMessageCallbackRegistered()) { logError("WebSocketServer Application developer error: Server callback improperly " - "registerered."); + "registered."); logError("Missing call to setOnMessageCallback inside setOnConnectionCallback."); connectionState->setTerminated(); return; @@ -99,9 +108,8 @@ namespace ix { WebSocket* webSocketRawPtr = webSocket.get(); webSocket->setOnMessageCallback( - [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) { - _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); - }); + [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) + { _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); }); } else { @@ -130,7 +138,7 @@ namespace ix } auto status = webSocket->connectToSocket( - std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate); + std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate, request); if (status.success) { // Process incoming messages and execute callbacks @@ -155,8 +163,6 @@ namespace ix logError("Cannot delete client"); } } - - connectionState->setTerminated(); } std::set> WebSocketServer::getClients() @@ -176,28 +182,30 @@ namespace ix // void WebSocketServer::makeBroadcastServer() { - setOnClientMessageCallback([this](std::shared_ptr connectionState, - WebSocket& webSocket, - const WebSocketMessagePtr& msg) { - auto remoteIp = connectionState->getRemoteIp(); - if (msg->type == ix::WebSocketMessageType::Message) + setOnClientMessageCallback( + [this](std::shared_ptr connectionState, + WebSocket& webSocket, + const WebSocketMessagePtr& msg) { - for (auto&& client : getClients()) + auto remoteIp = connectionState->getRemoteIp(); + if (msg->type == ix::WebSocketMessageType::Message) { - if (client.get() != &webSocket) + for (auto&& client : getClients()) { - client->send(msg->str, msg->binary); - - // Make sure the OS send buffer is flushed before moving on - do + if (client.get() != &webSocket) { - std::chrono::duration duration(500); - std::this_thread::sleep_for(duration); - } while (client->bufferedAmount() != 0); + client->send(msg->str, msg->binary); + + // Make sure the OS send buffer is flushed before moving on + do + { + std::chrono::duration duration(500); + std::this_thread::sleep_for(duration); + } while (client->bufferedAmount() != 0); + } } } - } - }); + }); } bool WebSocketServer::listenAndStart() diff --git a/ixwebsocket/IXWebSocketServer.h b/ixwebsocket/IXWebSocketServer.h index 6cae6331..dcb21e81 100644 --- a/ixwebsocket/IXWebSocketServer.h +++ b/ixwebsocket/IXWebSocketServer.h @@ -55,6 +55,7 @@ namespace ix int getHandshakeTimeoutSecs(); bool isPongEnabled(); bool isPerMessageDeflateEnabled(); + private: // Member variables int _handshakeTimeoutSecs; @@ -73,5 +74,10 @@ namespace ix virtual void handleConnection(std::unique_ptr socket, std::shared_ptr connectionState); virtual size_t getConnectedClientsCount() final; + + protected: + void handleUpgrade(std::unique_ptr socket, + std::shared_ptr connectionState, + HttpRequestPtr request = nullptr); }; } // namespace ix diff --git a/ixwebsocket/IXWebSocketTransport.cpp b/ixwebsocket/IXWebSocketTransport.cpp index 86ec52e8..f4d89e7e 100644 --- a/ixwebsocket/IXWebSocketTransport.cpp +++ b/ixwebsocket/IXWebSocketTransport.cpp @@ -170,7 +170,8 @@ namespace ix // Server WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr socket, int timeoutSecs, - bool enablePerMessageDeflate) + bool enablePerMessageDeflate, + HttpRequestPtr request) { std::lock_guard lock(_socketMutex); @@ -187,7 +188,8 @@ namespace ix _perMessageDeflateOptions, _enablePerMessageDeflate); - auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate); + auto result = + webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request); if (result.success) { setReadyState(ReadyState::OPEN); diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index bdfd409f..7e906daa 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -18,8 +18,8 @@ #include "IXWebSocketHttpHeaders.h" #include "IXWebSocketPerMessageDeflate.h" #include "IXWebSocketPerMessageDeflateOptions.h" -#include "IXWebSocketSendInfo.h" #include "IXWebSocketSendData.h" +#include "IXWebSocketSendInfo.h" #include #include #include @@ -86,7 +86,8 @@ namespace ix // Server WebSocketInitResult connectToSocket(std::unique_ptr socket, int timeoutSecs, - bool enablePerMessageDeflate); + bool enablePerMessageDeflate, + HttpRequestPtr request = nullptr); PollResult poll(); WebSocketSendInfo sendBinary(const IXWebSocketSendData& message,