Host HTTP and WS on the same port (#414)

* Quick hack to host HTTP and WS on the same port #373

* Quick hack to host HTTP and WS on the same port #373 - simplify code

* ran clang-format

Co-authored-by: En Shih <seanstone5923@gmail.com>
Co-authored-by: The Artful Bodger <TheArtfulBodger@users.noreply.github.com>
This commit is contained in:
TheArtfulBodger 2022-11-06 01:53:11 +00:00 committed by GitHub
parent 472cf68c31
commit b0fd119d14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 95 deletions

View File

@ -74,28 +74,14 @@ namespace ix
int backlog, int backlog,
size_t maxConnections, size_t maxConnections,
int addressFamily, int addressFamily,
int timeoutSecs) int timeoutSecs,
: SocketServer(port, host, backlog, maxConnections, addressFamily) int handshakeTimeoutSecs)
, _connectedClientsCount(0) : WebSocketServer(port, host, backlog, maxConnections, handshakeTimeoutSecs, addressFamily)
, _timeoutSecs(timeoutSecs) , _timeoutSecs(timeoutSecs)
{ {
setDefaultConnectionCallback(); setDefaultConnectionCallback();
} }
HttpServer::~HttpServer()
{
stop();
}
void HttpServer::stop()
{
stopAcceptingConnections();
// FIXME: cancelling / closing active clients ...
SocketServer::stop();
}
void HttpServer::setOnConnectionCallback(const OnConnectionCallback& callback) void HttpServer::setOnConnectionCallback(const OnConnectionCallback& callback)
{ {
_onConnectionCallback = callback; _onConnectionCallback = callback;
@ -104,34 +90,35 @@ namespace ix
void HttpServer::handleConnection(std::unique_ptr<Socket> socket, void HttpServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState) std::shared_ptr<ConnectionState> connectionState)
{ {
_connectedClientsCount++;
auto ret = Http::parseRequest(socket, _timeoutSecs); auto ret = Http::parseRequest(socket, _timeoutSecs);
// FIXME: handle errors in parseRequest // FIXME: handle errors in parseRequest
if (std::get<0>(ret)) if (std::get<0>(ret))
{ {
auto response = _onConnectionCallback(std::get<2>(ret), connectionState); auto request = std::get<2>(ret);
if (!Http::sendResponse(response, socket)) std::shared_ptr<ix::HttpResponse> response;
if (request->headers["Upgrade"] == "websocket")
{ {
logError("Cannot send response"); WebSocketServer::handleUpgrade(std::move(socket), connectionState, request);
}
else
{
auto response = _onConnectionCallback(request, connectionState);
if (!Http::sendResponse(response, socket))
{
logError("Cannot send response");
}
} }
} }
connectionState->setTerminated(); connectionState->setTerminated();
_connectedClientsCount--;
}
size_t HttpServer::getConnectedClientsCount()
{
return _connectedClientsCount;
} }
void HttpServer::setDefaultConnectionCallback() void HttpServer::setDefaultConnectionCallback()
{ {
setOnConnectionCallback( setOnConnectionCallback(
[this](HttpRequestPtr request, [this](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
{
std::string uri(request->uri); std::string uri(request->uri);
if (uri.empty() || uri == "/") if (uri.empty() || uri == "/")
{ {
@ -189,9 +176,9 @@ namespace ix
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections
// //
setOnConnectionCallback( setOnConnectionCallback(
[this, [this, redirectUrl](HttpRequestPtr request,
redirectUrl](HttpRequestPtr request, std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { {
WebSocketHttpHeaders headers; WebSocketHttpHeaders headers;
headers["Server"] = userAgent(); headers["Server"] = userAgent();
@ -222,7 +209,8 @@ namespace ix
{ {
setOnConnectionCallback( setOnConnectionCallback(
[this](HttpRequestPtr request, [this](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr { std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
{
WebSocketHttpHeaders headers; WebSocketHttpHeaders headers;
headers["Server"] = userAgent(); headers["Server"] = userAgent();

View File

@ -7,8 +7,8 @@
#pragma once #pragma once
#include "IXHttp.h" #include "IXHttp.h"
#include "IXSocketServer.h"
#include "IXWebSocket.h" #include "IXWebSocket.h"
#include "IXWebSocketServer.h"
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
@ -19,7 +19,7 @@
namespace ix namespace ix
{ {
class HttpServer final : public SocketServer class HttpServer final : public WebSocketServer
{ {
public: public:
using OnConnectionCallback = using OnConnectionCallback =
@ -30,9 +30,8 @@ namespace ix
int backlog = SocketServer::kDefaultTcpBacklog, int backlog = SocketServer::kDefaultTcpBacklog,
size_t maxConnections = SocketServer::kDefaultMaxConnections, size_t maxConnections = SocketServer::kDefaultMaxConnections,
int addressFamily = SocketServer::kDefaultAddressFamily, int addressFamily = SocketServer::kDefaultAddressFamily,
int timeoutSecs = HttpServer::kDefaultTimeoutSecs); int timeoutSecs = HttpServer::kDefaultTimeoutSecs,
virtual ~HttpServer(); int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs);
virtual void stop() final;
void setOnConnectionCallback(const OnConnectionCallback& callback); void setOnConnectionCallback(const OnConnectionCallback& callback);
@ -41,10 +40,10 @@ namespace ix
void makeDebugServer(); void makeDebugServer();
int getTimeoutSecs(); int getTimeoutSecs();
private: private:
// Member variables // Member variables
OnConnectionCallback _onConnectionCallback; OnConnectionCallback _onConnectionCallback;
std::atomic<int> _connectedClientsCount;
const static int kDefaultTimeoutSecs; const static int kDefaultTimeoutSecs;
int _timeoutSecs; int _timeoutSecs;
@ -52,7 +51,6 @@ namespace ix
// Methods // Methods
virtual void handleConnection(std::unique_ptr<Socket>, virtual void handleConnection(std::unique_ptr<Socket>,
std::shared_ptr<ConnectionState> connectionState) final; std::shared_ptr<ConnectionState> connectionState) final;
virtual size_t getConnectedClientsCount() final;
void setDefaultConnectionCallback(); void setDefaultConnectionCallback();
}; };

View File

@ -41,7 +41,8 @@ namespace ix
, _pingIntervalSecs(kDefaultPingIntervalSecs) , _pingIntervalSecs(kDefaultPingIntervalSecs)
{ {
_ws.setOnCloseCallback( _ws.setOnCloseCallback(
[this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) { [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote)
{
_onMessageCallback( _onMessageCallback(
ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close, ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close,
emptyMsg, emptyMsg,
@ -240,7 +241,8 @@ namespace ix
WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket, WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate) bool enablePerMessageDeflate,
HttpRequestPtr request)
{ {
{ {
std::lock_guard<std::mutex> lock(_configMutex); std::lock_guard<std::mutex> lock(_configMutex);
@ -249,7 +251,7 @@ namespace ix
} }
WebSocketInitResult status = WebSocketInitResult status =
_ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate); _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request);
if (!status.success) if (!status.success)
{ {
return status; return status;
@ -384,8 +386,9 @@ namespace ix
[this](const std::string& msg, [this](const std::string& msg,
size_t wireSize, size_t wireSize,
bool decompressionError, bool decompressionError,
WebSocketTransport::MessageKind messageKind) { WebSocketTransport::MessageKind messageKind)
WebSocketMessageType webSocketMessageType{WebSocketMessageType::Error}; {
WebSocketMessageType webSocketMessageType {WebSocketMessageType::Error};
switch (messageKind) switch (messageKind)
{ {
case WebSocketTransport::MessageKind::MSG_TEXT: case WebSocketTransport::MessageKind::MSG_TEXT:

View File

@ -16,8 +16,8 @@
#include "IXWebSocketHttpHeaders.h" #include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketMessage.h" #include "IXWebSocketMessage.h"
#include "IXWebSocketPerMessageDeflateOptions.h" #include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketSendData.h" #include "IXWebSocketSendData.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketTransport.h" #include "IXWebSocketTransport.h"
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
@ -128,7 +128,8 @@ namespace ix
// Server // Server
WebSocketInitResult connectToSocket(std::unique_ptr<Socket>, WebSocketInitResult connectToSocket(std::unique_ptr<Socket>,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate); bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
WebSocketTransport _ws; WebSocketTransport _ws;

View File

@ -240,28 +240,42 @@ namespace ix
} }
WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate) bool enablePerMessageDeflate,
HttpRequestPtr request)
{ {
_requestInitCancellation = false; _requestInitCancellation = false;
auto isCancellationRequested = auto isCancellationRequested =
makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
// Read first line std::string method;
auto lineResult = _socket->readLine(isCancellationRequested); std::string uri;
auto lineValid = lineResult.first; std::string httpVersion;
auto line = lineResult.second;
if (!lineValid) if (request)
{ {
return sendErrorResponse(400, "Error reading HTTP request line"); method = request->method;
uri = request->uri;
httpVersion = request->version;
} }
else
{
// Read first line
auto lineResult = _socket->readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
// Validate request line (GET /foo HTTP/1.1\r\n) if (!lineValid)
auto requestLine = Http::parseRequestLine(line); {
auto method = std::get<0>(requestLine); return sendErrorResponse(400, "Error reading HTTP request line");
auto uri = std::get<1>(requestLine); }
auto httpVersion = std::get<2>(requestLine);
// Validate request line (GET /foo HTTP/1.1\r\n)
auto requestLine = Http::parseRequestLine(line);
method = std::get<0>(requestLine);
uri = std::get<1>(requestLine);
httpVersion = std::get<2>(requestLine);
}
if (method != "GET") if (method != "GET")
{ {
@ -274,14 +288,22 @@ namespace ix
"Invalid HTTP version, need HTTP/1.1, got: " + httpVersion); "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
} }
// Retrieve and validate HTTP headers WebSocketHttpHeaders headers;
auto result = parseHttpHeaders(_socket, isCancellationRequested); if (request)
auto headersValid = result.first;
auto headers = result.second;
if (!headersValid)
{ {
return sendErrorResponse(400, "Error parsing HTTP headers"); headers = request->headers;
}
else
{
// Retrieve and validate HTTP headers
auto result = parseHttpHeaders(_socket, isCancellationRequested);
auto headersValid = result.first;
headers = result.second;
if (!headersValid)
{
return sendErrorResponse(400, "Error parsing HTTP headers");
}
} }
if (headers.find("sec-websocket-key") == headers.end()) if (headers.find("sec-websocket-key") == headers.end())

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "IXCancellationRequest.h" #include "IXCancellationRequest.h"
#include "IXHttp.h"
#include "IXSocket.h" #include "IXSocket.h"
#include "IXWebSocketHttpHeaders.h" #include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketInitResult.h" #include "IXWebSocketInitResult.h"
@ -35,7 +36,9 @@ namespace ix
int port, int port,
int timeoutSecs); int timeoutSecs);
WebSocketInitResult serverHandshake(int timeoutSecs, bool enablePerMessageDeflate); WebSocketInitResult serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
private: private:
std::string genRandomString(const int len); std::string genRandomString(const int len);

View File

@ -78,6 +78,15 @@ namespace ix
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket, void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState) std::shared_ptr<ConnectionState> connectionState)
{
handleUpgrade(std::move(socket), connectionState);
connectionState->setTerminated();
}
void WebSocketServer::handleUpgrade(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState,
HttpRequestPtr request)
{ {
setThreadName("Srv:ws:" + connectionState->getId()); setThreadName("Srv:ws:" + connectionState->getId());
@ -89,7 +98,7 @@ namespace ix
if (!webSocket->isOnMessageCallbackRegistered()) if (!webSocket->isOnMessageCallbackRegistered())
{ {
logError("WebSocketServer Application developer error: Server callback improperly " logError("WebSocketServer Application developer error: Server callback improperly "
"registerered."); "registered.");
logError("Missing call to setOnMessageCallback inside setOnConnectionCallback."); logError("Missing call to setOnMessageCallback inside setOnConnectionCallback.");
connectionState->setTerminated(); connectionState->setTerminated();
return; return;
@ -99,9 +108,8 @@ namespace ix
{ {
WebSocket* webSocketRawPtr = webSocket.get(); WebSocket* webSocketRawPtr = webSocket.get();
webSocket->setOnMessageCallback( webSocket->setOnMessageCallback(
[this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) { [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg)
_onClientMessageCallback(connectionState, *webSocketRawPtr, msg); { _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); });
});
} }
else else
{ {
@ -130,7 +138,7 @@ namespace ix
} }
auto status = webSocket->connectToSocket( auto status = webSocket->connectToSocket(
std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate); std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate, request);
if (status.success) if (status.success)
{ {
// Process incoming messages and execute callbacks // Process incoming messages and execute callbacks
@ -155,8 +163,6 @@ namespace ix
logError("Cannot delete client"); logError("Cannot delete client");
} }
} }
connectionState->setTerminated();
} }
std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients() std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients()
@ -176,28 +182,30 @@ namespace ix
// //
void WebSocketServer::makeBroadcastServer() void WebSocketServer::makeBroadcastServer()
{ {
setOnClientMessageCallback([this](std::shared_ptr<ConnectionState> connectionState, setOnClientMessageCallback(
WebSocket& webSocket, [this](std::shared_ptr<ConnectionState> connectionState,
const WebSocketMessagePtr& msg) { WebSocket& webSocket,
auto remoteIp = connectionState->getRemoteIp(); const WebSocketMessagePtr& msg)
if (msg->type == ix::WebSocketMessageType::Message)
{ {
for (auto&& client : getClients()) auto remoteIp = connectionState->getRemoteIp();
if (msg->type == ix::WebSocketMessageType::Message)
{ {
if (client.get() != &webSocket) for (auto&& client : getClients())
{ {
client->send(msg->str, msg->binary); if (client.get() != &webSocket)
// Make sure the OS send buffer is flushed before moving on
do
{ {
std::chrono::duration<double, std::milli> duration(500); client->send(msg->str, msg->binary);
std::this_thread::sleep_for(duration);
} while (client->bufferedAmount() != 0); // Make sure the OS send buffer is flushed before moving on
do
{
std::chrono::duration<double, std::milli> duration(500);
std::this_thread::sleep_for(duration);
} while (client->bufferedAmount() != 0);
}
} }
} }
} });
});
} }
bool WebSocketServer::listenAndStart() bool WebSocketServer::listenAndStart()

View File

@ -55,6 +55,7 @@ namespace ix
int getHandshakeTimeoutSecs(); int getHandshakeTimeoutSecs();
bool isPongEnabled(); bool isPongEnabled();
bool isPerMessageDeflateEnabled(); bool isPerMessageDeflateEnabled();
private: private:
// Member variables // Member variables
int _handshakeTimeoutSecs; int _handshakeTimeoutSecs;
@ -73,5 +74,10 @@ namespace ix
virtual void handleConnection(std::unique_ptr<Socket> socket, virtual void handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState); std::shared_ptr<ConnectionState> connectionState);
virtual size_t getConnectedClientsCount() final; virtual size_t getConnectedClientsCount() final;
protected:
void handleUpgrade(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState,
HttpRequestPtr request = nullptr);
}; };
} // namespace ix } // namespace ix

View File

@ -170,7 +170,8 @@ namespace ix
// Server // Server
WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket, WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate) bool enablePerMessageDeflate,
HttpRequestPtr request)
{ {
std::lock_guard<std::mutex> lock(_socketMutex); std::lock_guard<std::mutex> lock(_socketMutex);
@ -187,7 +188,8 @@ namespace ix
_perMessageDeflateOptions, _perMessageDeflateOptions,
_enablePerMessageDeflate); _enablePerMessageDeflate);
auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate); auto result =
webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request);
if (result.success) if (result.success)
{ {
setReadyState(ReadyState::OPEN); setReadyState(ReadyState::OPEN);

View File

@ -18,8 +18,8 @@
#include "IXWebSocketHttpHeaders.h" #include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketPerMessageDeflate.h" #include "IXWebSocketPerMessageDeflate.h"
#include "IXWebSocketPerMessageDeflateOptions.h" #include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketSendData.h" #include "IXWebSocketSendData.h"
#include "IXWebSocketSendInfo.h"
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#include <list> #include <list>
@ -86,7 +86,8 @@ namespace ix
// Server // Server
WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket, WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs, int timeoutSecs,
bool enablePerMessageDeflate); bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
PollResult poll(); PollResult poll();
WebSocketSendInfo sendBinary(const IXWebSocketSendData& message, WebSocketSendInfo sendBinary(const IXWebSocketSendData& message,