Host HTTP and WS on the same port (#414)
* Quick hack to host HTTP and WS on the same port #373 * Quick hack to host HTTP and WS on the same port #373 - simplify code * ran clang-format Co-authored-by: En Shih <seanstone5923@gmail.com> Co-authored-by: The Artful Bodger <TheArtfulBodger@users.noreply.github.com>
This commit is contained in:
		| @@ -74,28 +74,14 @@ namespace ix | ||||
|                            int backlog, | ||||
|                            size_t maxConnections, | ||||
|                            int addressFamily, | ||||
|                            int timeoutSecs) | ||||
|         : SocketServer(port, host, backlog, maxConnections, addressFamily) | ||||
|         , _connectedClientsCount(0) | ||||
|                            int timeoutSecs, | ||||
|                            int handshakeTimeoutSecs) | ||||
|         : WebSocketServer(port, host, backlog, maxConnections, handshakeTimeoutSecs, addressFamily) | ||||
|         , _timeoutSecs(timeoutSecs) | ||||
|     { | ||||
|         setDefaultConnectionCallback(); | ||||
|     } | ||||
|  | ||||
|     HttpServer::~HttpServer() | ||||
|     { | ||||
|         stop(); | ||||
|     } | ||||
|  | ||||
|     void HttpServer::stop() | ||||
|     { | ||||
|         stopAcceptingConnections(); | ||||
|  | ||||
|         // FIXME: cancelling / closing active clients ... | ||||
|  | ||||
|         SocketServer::stop(); | ||||
|     } | ||||
|  | ||||
|     void HttpServer::setOnConnectionCallback(const OnConnectionCallback& callback) | ||||
|     { | ||||
|         _onConnectionCallback = callback; | ||||
| @@ -104,34 +90,35 @@ namespace ix | ||||
|     void HttpServer::handleConnection(std::unique_ptr<Socket> socket, | ||||
|                                       std::shared_ptr<ConnectionState> connectionState) | ||||
|     { | ||||
|         _connectedClientsCount++; | ||||
|  | ||||
|         auto ret = Http::parseRequest(socket, _timeoutSecs); | ||||
|         // FIXME: handle errors in parseRequest | ||||
|  | ||||
|         if (std::get<0>(ret)) | ||||
|         { | ||||
|             auto response = _onConnectionCallback(std::get<2>(ret), connectionState); | ||||
|             if (!Http::sendResponse(response, socket)) | ||||
|             auto request = std::get<2>(ret); | ||||
|             std::shared_ptr<ix::HttpResponse> response; | ||||
|             if (request->headers["Upgrade"] == "websocket") | ||||
|             { | ||||
|                 logError("Cannot send response"); | ||||
|                 WebSocketServer::handleUpgrade(std::move(socket), connectionState, request); | ||||
|             } | ||||
|             else | ||||
|             { | ||||
|                 auto response = _onConnectionCallback(request, connectionState); | ||||
|                 if (!Http::sendResponse(response, socket)) | ||||
|                 { | ||||
|                     logError("Cannot send response"); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         connectionState->setTerminated(); | ||||
|  | ||||
|         _connectedClientsCount--; | ||||
|     } | ||||
|  | ||||
|     size_t HttpServer::getConnectedClientsCount() | ||||
|     { | ||||
|         return _connectedClientsCount; | ||||
|     } | ||||
|  | ||||
|     void HttpServer::setDefaultConnectionCallback() | ||||
|     { | ||||
|         setOnConnectionCallback( | ||||
|             [this](HttpRequestPtr request, | ||||
|                    std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { | ||||
|                    std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr | ||||
|             { | ||||
|                 std::string uri(request->uri); | ||||
|                 if (uri.empty() || uri == "/") | ||||
|                 { | ||||
| @@ -189,9 +176,9 @@ namespace ix | ||||
|         // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections | ||||
|         // | ||||
|         setOnConnectionCallback( | ||||
|             [this, | ||||
|              redirectUrl](HttpRequestPtr request, | ||||
|                           std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { | ||||
|             [this, redirectUrl](HttpRequestPtr request, | ||||
|                                 std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr | ||||
|             { | ||||
|                 WebSocketHttpHeaders headers; | ||||
|                 headers["Server"] = userAgent(); | ||||
|  | ||||
| @@ -222,7 +209,8 @@ namespace ix | ||||
|     { | ||||
|         setOnConnectionCallback( | ||||
|             [this](HttpRequestPtr request, | ||||
|                    std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { | ||||
|                    std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr | ||||
|             { | ||||
|                 WebSocketHttpHeaders headers; | ||||
|                 headers["Server"] = userAgent(); | ||||
|  | ||||
|   | ||||
| @@ -7,8 +7,8 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "IXHttp.h" | ||||
| #include "IXSocketServer.h" | ||||
| #include "IXWebSocket.h" | ||||
| #include "IXWebSocketServer.h" | ||||
| #include <functional> | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
| @@ -19,7 +19,7 @@ | ||||
|  | ||||
| namespace ix | ||||
| { | ||||
|     class HttpServer final : public SocketServer | ||||
|     class HttpServer final : public WebSocketServer | ||||
|     { | ||||
|     public: | ||||
|         using OnConnectionCallback = | ||||
| @@ -30,9 +30,8 @@ namespace ix | ||||
|                    int backlog = SocketServer::kDefaultTcpBacklog, | ||||
|                    size_t maxConnections = SocketServer::kDefaultMaxConnections, | ||||
|                    int addressFamily = SocketServer::kDefaultAddressFamily, | ||||
|                    int timeoutSecs = HttpServer::kDefaultTimeoutSecs); | ||||
|         virtual ~HttpServer(); | ||||
|         virtual void stop() final; | ||||
|                    int timeoutSecs = HttpServer::kDefaultTimeoutSecs, | ||||
|                    int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs); | ||||
|  | ||||
|         void setOnConnectionCallback(const OnConnectionCallback& callback); | ||||
|  | ||||
| @@ -41,10 +40,10 @@ namespace ix | ||||
|         void makeDebugServer(); | ||||
|  | ||||
|         int getTimeoutSecs(); | ||||
|  | ||||
|     private: | ||||
|         // Member variables | ||||
|         OnConnectionCallback _onConnectionCallback; | ||||
|         std::atomic<int> _connectedClientsCount; | ||||
|  | ||||
|         const static int kDefaultTimeoutSecs; | ||||
|         int _timeoutSecs; | ||||
| @@ -52,7 +51,6 @@ namespace ix | ||||
|         // Methods | ||||
|         virtual void handleConnection(std::unique_ptr<Socket>, | ||||
|                                       std::shared_ptr<ConnectionState> connectionState) final; | ||||
|         virtual size_t getConnectedClientsCount() final; | ||||
|  | ||||
|         void setDefaultConnectionCallback(); | ||||
|     }; | ||||
|   | ||||
| @@ -41,7 +41,8 @@ namespace ix | ||||
|         , _pingIntervalSecs(kDefaultPingIntervalSecs) | ||||
|     { | ||||
|         _ws.setOnCloseCallback( | ||||
|             [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) { | ||||
|             [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) | ||||
|             { | ||||
|                 _onMessageCallback( | ||||
|                     ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close, | ||||
|                                                       emptyMsg, | ||||
| @@ -240,7 +241,8 @@ namespace ix | ||||
|  | ||||
|     WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket, | ||||
|                                                    int timeoutSecs, | ||||
|                                                    bool enablePerMessageDeflate) | ||||
|                                                    bool enablePerMessageDeflate, | ||||
|                                                    HttpRequestPtr request) | ||||
|     { | ||||
|         { | ||||
|             std::lock_guard<std::mutex> lock(_configMutex); | ||||
| @@ -249,7 +251,7 @@ namespace ix | ||||
|         } | ||||
|  | ||||
|         WebSocketInitResult status = | ||||
|             _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate); | ||||
|             _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request); | ||||
|         if (!status.success) | ||||
|         { | ||||
|             return status; | ||||
| @@ -384,8 +386,9 @@ namespace ix | ||||
|                 [this](const std::string& msg, | ||||
|                        size_t wireSize, | ||||
|                        bool decompressionError, | ||||
|                        WebSocketTransport::MessageKind messageKind) { | ||||
|                     WebSocketMessageType webSocketMessageType{WebSocketMessageType::Error}; | ||||
|                        WebSocketTransport::MessageKind messageKind) | ||||
|                 { | ||||
|                     WebSocketMessageType webSocketMessageType {WebSocketMessageType::Error}; | ||||
|                     switch (messageKind) | ||||
|                     { | ||||
|                         case WebSocketTransport::MessageKind::MSG_TEXT: | ||||
|   | ||||
| @@ -16,8 +16,8 @@ | ||||
| #include "IXWebSocketHttpHeaders.h" | ||||
| #include "IXWebSocketMessage.h" | ||||
| #include "IXWebSocketPerMessageDeflateOptions.h" | ||||
| #include "IXWebSocketSendInfo.h" | ||||
| #include "IXWebSocketSendData.h" | ||||
| #include "IXWebSocketSendInfo.h" | ||||
| #include "IXWebSocketTransport.h" | ||||
| #include <atomic> | ||||
| #include <condition_variable> | ||||
| @@ -128,7 +128,8 @@ namespace ix | ||||
|         // Server | ||||
|         WebSocketInitResult connectToSocket(std::unique_ptr<Socket>, | ||||
|                                             int timeoutSecs, | ||||
|                                             bool enablePerMessageDeflate); | ||||
|                                             bool enablePerMessageDeflate, | ||||
|                                             HttpRequestPtr request = nullptr); | ||||
|  | ||||
|         WebSocketTransport _ws; | ||||
|  | ||||
|   | ||||
| @@ -240,28 +240,42 @@ namespace ix | ||||
|     } | ||||
|  | ||||
|     WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, | ||||
|                                                             bool enablePerMessageDeflate) | ||||
|                                                             bool enablePerMessageDeflate, | ||||
|                                                             HttpRequestPtr request) | ||||
|     { | ||||
|         _requestInitCancellation = false; | ||||
|  | ||||
|         auto isCancellationRequested = | ||||
|             makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); | ||||
|  | ||||
|         // Read first line | ||||
|         auto lineResult = _socket->readLine(isCancellationRequested); | ||||
|         auto lineValid = lineResult.first; | ||||
|         auto line = lineResult.second; | ||||
|         std::string method; | ||||
|         std::string uri; | ||||
|         std::string httpVersion; | ||||
|  | ||||
|         if (!lineValid) | ||||
|         if (request) | ||||
|         { | ||||
|             return sendErrorResponse(400, "Error reading HTTP request line"); | ||||
|             method = request->method; | ||||
|             uri = request->uri; | ||||
|             httpVersion = request->version; | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|             // Read first line | ||||
|             auto lineResult = _socket->readLine(isCancellationRequested); | ||||
|             auto lineValid = lineResult.first; | ||||
|             auto line = lineResult.second; | ||||
|  | ||||
|         // Validate request line (GET /foo HTTP/1.1\r\n) | ||||
|         auto requestLine = Http::parseRequestLine(line); | ||||
|         auto method = std::get<0>(requestLine); | ||||
|         auto uri = std::get<1>(requestLine); | ||||
|         auto httpVersion = std::get<2>(requestLine); | ||||
|             if (!lineValid) | ||||
|             { | ||||
|                 return sendErrorResponse(400, "Error reading HTTP request line"); | ||||
|             } | ||||
|  | ||||
|             // Validate request line (GET /foo HTTP/1.1\r\n) | ||||
|             auto requestLine = Http::parseRequestLine(line); | ||||
|             method = std::get<0>(requestLine); | ||||
|             uri = std::get<1>(requestLine); | ||||
|             httpVersion = std::get<2>(requestLine); | ||||
|         } | ||||
|  | ||||
|         if (method != "GET") | ||||
|         { | ||||
| @@ -274,14 +288,22 @@ namespace ix | ||||
|                                      "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion); | ||||
|         } | ||||
|  | ||||
|         // Retrieve and validate HTTP headers | ||||
|         auto result = parseHttpHeaders(_socket, isCancellationRequested); | ||||
|         auto headersValid = result.first; | ||||
|         auto headers = result.second; | ||||
|  | ||||
|         if (!headersValid) | ||||
|         WebSocketHttpHeaders headers; | ||||
|         if (request) | ||||
|         { | ||||
|             return sendErrorResponse(400, "Error parsing HTTP headers"); | ||||
|             headers = request->headers; | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|             // Retrieve and validate HTTP headers | ||||
|             auto result = parseHttpHeaders(_socket, isCancellationRequested); | ||||
|             auto headersValid = result.first; | ||||
|             headers = result.second; | ||||
|  | ||||
|             if (!headersValid) | ||||
|             { | ||||
|                 return sendErrorResponse(400, "Error parsing HTTP headers"); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (headers.find("sec-websocket-key") == headers.end()) | ||||
|   | ||||
| @@ -7,6 +7,7 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "IXCancellationRequest.h" | ||||
| #include "IXHttp.h" | ||||
| #include "IXSocket.h" | ||||
| #include "IXWebSocketHttpHeaders.h" | ||||
| #include "IXWebSocketInitResult.h" | ||||
| @@ -35,7 +36,9 @@ namespace ix | ||||
|                                             int port, | ||||
|                                             int timeoutSecs); | ||||
|  | ||||
|         WebSocketInitResult serverHandshake(int timeoutSecs, bool enablePerMessageDeflate); | ||||
|         WebSocketInitResult serverHandshake(int timeoutSecs, | ||||
|                                             bool enablePerMessageDeflate, | ||||
|                                             HttpRequestPtr request = nullptr); | ||||
|  | ||||
|     private: | ||||
|         std::string genRandomString(const int len); | ||||
|   | ||||
| @@ -78,6 +78,15 @@ namespace ix | ||||
|  | ||||
|     void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket, | ||||
|                                            std::shared_ptr<ConnectionState> connectionState) | ||||
|     { | ||||
|         handleUpgrade(std::move(socket), connectionState); | ||||
|  | ||||
|         connectionState->setTerminated(); | ||||
|     } | ||||
|  | ||||
|     void WebSocketServer::handleUpgrade(std::unique_ptr<Socket> socket, | ||||
|                                         std::shared_ptr<ConnectionState> connectionState, | ||||
|                                         HttpRequestPtr request) | ||||
|     { | ||||
|         setThreadName("Srv:ws:" + connectionState->getId()); | ||||
|  | ||||
| @@ -89,7 +98,7 @@ namespace ix | ||||
|             if (!webSocket->isOnMessageCallbackRegistered()) | ||||
|             { | ||||
|                 logError("WebSocketServer Application developer error: Server callback improperly " | ||||
|                          "registerered."); | ||||
|                          "registered."); | ||||
|                 logError("Missing call to setOnMessageCallback inside setOnConnectionCallback."); | ||||
|                 connectionState->setTerminated(); | ||||
|                 return; | ||||
| @@ -99,9 +108,8 @@ namespace ix | ||||
|         { | ||||
|             WebSocket* webSocketRawPtr = webSocket.get(); | ||||
|             webSocket->setOnMessageCallback( | ||||
|                 [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) { | ||||
|                     _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); | ||||
|                 }); | ||||
|                 [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) | ||||
|                 { _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); }); | ||||
|         } | ||||
|         else | ||||
|         { | ||||
| @@ -130,7 +138,7 @@ namespace ix | ||||
|         } | ||||
|  | ||||
|         auto status = webSocket->connectToSocket( | ||||
|             std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate); | ||||
|             std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate, request); | ||||
|         if (status.success) | ||||
|         { | ||||
|             // Process incoming messages and execute callbacks | ||||
| @@ -155,8 +163,6 @@ namespace ix | ||||
|                 logError("Cannot delete client"); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         connectionState->setTerminated(); | ||||
|     } | ||||
|  | ||||
|     std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients() | ||||
| @@ -176,28 +182,30 @@ namespace ix | ||||
|     // | ||||
|     void WebSocketServer::makeBroadcastServer() | ||||
|     { | ||||
|         setOnClientMessageCallback([this](std::shared_ptr<ConnectionState> connectionState, | ||||
|                                           WebSocket& webSocket, | ||||
|                                           const WebSocketMessagePtr& msg) { | ||||
|             auto remoteIp = connectionState->getRemoteIp(); | ||||
|             if (msg->type == ix::WebSocketMessageType::Message) | ||||
|         setOnClientMessageCallback( | ||||
|             [this](std::shared_ptr<ConnectionState> connectionState, | ||||
|                    WebSocket& webSocket, | ||||
|                    const WebSocketMessagePtr& msg) | ||||
|             { | ||||
|                 for (auto&& client : getClients()) | ||||
|                 auto remoteIp = connectionState->getRemoteIp(); | ||||
|                 if (msg->type == ix::WebSocketMessageType::Message) | ||||
|                 { | ||||
|                     if (client.get() != &webSocket) | ||||
|                     for (auto&& client : getClients()) | ||||
|                     { | ||||
|                         client->send(msg->str, msg->binary); | ||||
|  | ||||
|                         // Make sure the OS send buffer is flushed before moving on | ||||
|                         do | ||||
|                         if (client.get() != &webSocket) | ||||
|                         { | ||||
|                             std::chrono::duration<double, std::milli> duration(500); | ||||
|                             std::this_thread::sleep_for(duration); | ||||
|                         } while (client->bufferedAmount() != 0); | ||||
|                             client->send(msg->str, msg->binary); | ||||
|  | ||||
|                             // Make sure the OS send buffer is flushed before moving on | ||||
|                             do | ||||
|                             { | ||||
|                                 std::chrono::duration<double, std::milli> duration(500); | ||||
|                                 std::this_thread::sleep_for(duration); | ||||
|                             } while (client->bufferedAmount() != 0); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
|             }); | ||||
|     } | ||||
|  | ||||
|     bool WebSocketServer::listenAndStart() | ||||
|   | ||||
| @@ -55,6 +55,7 @@ namespace ix | ||||
|         int getHandshakeTimeoutSecs(); | ||||
|         bool isPongEnabled(); | ||||
|         bool isPerMessageDeflateEnabled(); | ||||
|  | ||||
|     private: | ||||
|         // Member variables | ||||
|         int _handshakeTimeoutSecs; | ||||
| @@ -73,5 +74,10 @@ namespace ix | ||||
|         virtual void handleConnection(std::unique_ptr<Socket> socket, | ||||
|                                       std::shared_ptr<ConnectionState> connectionState); | ||||
|         virtual size_t getConnectedClientsCount() final; | ||||
|  | ||||
|     protected: | ||||
|         void handleUpgrade(std::unique_ptr<Socket> socket, | ||||
|                            std::shared_ptr<ConnectionState> connectionState, | ||||
|                            HttpRequestPtr request = nullptr); | ||||
|     }; | ||||
| } // namespace ix | ||||
|   | ||||
| @@ -170,7 +170,8 @@ namespace ix | ||||
|     // Server | ||||
|     WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket, | ||||
|                                                             int timeoutSecs, | ||||
|                                                             bool enablePerMessageDeflate) | ||||
|                                                             bool enablePerMessageDeflate, | ||||
|                                                             HttpRequestPtr request) | ||||
|     { | ||||
|         std::lock_guard<std::mutex> lock(_socketMutex); | ||||
|  | ||||
| @@ -187,7 +188,8 @@ namespace ix | ||||
|                                               _perMessageDeflateOptions, | ||||
|                                               _enablePerMessageDeflate); | ||||
|  | ||||
|         auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate); | ||||
|         auto result = | ||||
|             webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request); | ||||
|         if (result.success) | ||||
|         { | ||||
|             setReadyState(ReadyState::OPEN); | ||||
|   | ||||
| @@ -18,8 +18,8 @@ | ||||
| #include "IXWebSocketHttpHeaders.h" | ||||
| #include "IXWebSocketPerMessageDeflate.h" | ||||
| #include "IXWebSocketPerMessageDeflateOptions.h" | ||||
| #include "IXWebSocketSendInfo.h" | ||||
| #include "IXWebSocketSendData.h" | ||||
| #include "IXWebSocketSendInfo.h" | ||||
| #include <atomic> | ||||
| #include <functional> | ||||
| #include <list> | ||||
| @@ -86,7 +86,8 @@ namespace ix | ||||
|         // Server | ||||
|         WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket, | ||||
|                                             int timeoutSecs, | ||||
|                                             bool enablePerMessageDeflate); | ||||
|                                             bool enablePerMessageDeflate, | ||||
|                                             HttpRequestPtr request = nullptr); | ||||
|  | ||||
|         PollResult poll(); | ||||
|         WebSocketSendInfo sendBinary(const IXWebSocketSendData& message, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user