Host HTTP and WS on the same port (#414)

* Quick hack to host HTTP and WS on the same port #373

* Quick hack to host HTTP and WS on the same port #373 - simplify code

* ran clang-format

Co-authored-by: En Shih <seanstone5923@gmail.com>
Co-authored-by: The Artful Bodger <TheArtfulBodger@users.noreply.github.com>
This commit is contained in:
TheArtfulBodger 2022-11-06 01:53:11 +00:00 committed by GitHub
parent 472cf68c31
commit b0fd119d14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 95 deletions

View File

@ -74,28 +74,14 @@ namespace ix
int backlog,
size_t maxConnections,
int addressFamily,
int timeoutSecs)
: SocketServer(port, host, backlog, maxConnections, addressFamily)
, _connectedClientsCount(0)
int timeoutSecs,
int handshakeTimeoutSecs)
: WebSocketServer(port, host, backlog, maxConnections, handshakeTimeoutSecs, addressFamily)
, _timeoutSecs(timeoutSecs)
{
setDefaultConnectionCallback();
}
HttpServer::~HttpServer()
{
stop();
}
void HttpServer::stop()
{
stopAcceptingConnections();
// FIXME: cancelling / closing active clients ...
SocketServer::stop();
}
void HttpServer::setOnConnectionCallback(const OnConnectionCallback& callback)
{
_onConnectionCallback = callback;
@ -104,34 +90,35 @@ namespace ix
void HttpServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{
_connectedClientsCount++;
auto ret = Http::parseRequest(socket, _timeoutSecs);
// FIXME: handle errors in parseRequest
if (std::get<0>(ret))
{
auto response = _onConnectionCallback(std::get<2>(ret), connectionState);
auto request = std::get<2>(ret);
std::shared_ptr<ix::HttpResponse> response;
if (request->headers["Upgrade"] == "websocket")
{
WebSocketServer::handleUpgrade(std::move(socket), connectionState, request);
}
else
{
auto response = _onConnectionCallback(request, connectionState);
if (!Http::sendResponse(response, socket))
{
logError("Cannot send response");
}
}
connectionState->setTerminated();
_connectedClientsCount--;
}
size_t HttpServer::getConnectedClientsCount()
{
return _connectedClientsCount;
connectionState->setTerminated();
}
void HttpServer::setDefaultConnectionCallback()
{
setOnConnectionCallback(
[this](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr {
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
{
std::string uri(request->uri);
if (uri.empty() || uri == "/")
{
@ -189,9 +176,9 @@ namespace ix
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections
//
setOnConnectionCallback(
[this,
redirectUrl](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr {
[this, redirectUrl](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
{
WebSocketHttpHeaders headers;
headers["Server"] = userAgent();
@ -222,7 +209,8 @@ namespace ix
{
setOnConnectionCallback(
[this](HttpRequestPtr request,
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr {
std::shared_ptr<ConnectionState> connectionState) -> HttpResponsePtr
{
WebSocketHttpHeaders headers;
headers["Server"] = userAgent();

View File

@ -7,8 +7,8 @@
#pragma once
#include "IXHttp.h"
#include "IXSocketServer.h"
#include "IXWebSocket.h"
#include "IXWebSocketServer.h"
#include <functional>
#include <memory>
#include <mutex>
@ -19,7 +19,7 @@
namespace ix
{
class HttpServer final : public SocketServer
class HttpServer final : public WebSocketServer
{
public:
using OnConnectionCallback =
@ -30,9 +30,8 @@ namespace ix
int backlog = SocketServer::kDefaultTcpBacklog,
size_t maxConnections = SocketServer::kDefaultMaxConnections,
int addressFamily = SocketServer::kDefaultAddressFamily,
int timeoutSecs = HttpServer::kDefaultTimeoutSecs);
virtual ~HttpServer();
virtual void stop() final;
int timeoutSecs = HttpServer::kDefaultTimeoutSecs,
int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs);
void setOnConnectionCallback(const OnConnectionCallback& callback);
@ -41,10 +40,10 @@ namespace ix
void makeDebugServer();
int getTimeoutSecs();
private:
// Member variables
OnConnectionCallback _onConnectionCallback;
std::atomic<int> _connectedClientsCount;
const static int kDefaultTimeoutSecs;
int _timeoutSecs;
@ -52,7 +51,6 @@ namespace ix
// Methods
virtual void handleConnection(std::unique_ptr<Socket>,
std::shared_ptr<ConnectionState> connectionState) final;
virtual size_t getConnectedClientsCount() final;
void setDefaultConnectionCallback();
};

View File

@ -41,7 +41,8 @@ namespace ix
, _pingIntervalSecs(kDefaultPingIntervalSecs)
{
_ws.setOnCloseCallback(
[this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) {
[this](uint16_t code, const std::string& reason, size_t wireSize, bool remote)
{
_onMessageCallback(
ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close,
emptyMsg,
@ -240,7 +241,8 @@ namespace ix
WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs,
bool enablePerMessageDeflate)
bool enablePerMessageDeflate,
HttpRequestPtr request)
{
{
std::lock_guard<std::mutex> lock(_configMutex);
@ -249,7 +251,7 @@ namespace ix
}
WebSocketInitResult status =
_ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate);
_ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request);
if (!status.success)
{
return status;
@ -384,7 +386,8 @@ namespace ix
[this](const std::string& msg,
size_t wireSize,
bool decompressionError,
WebSocketTransport::MessageKind messageKind) {
WebSocketTransport::MessageKind messageKind)
{
WebSocketMessageType webSocketMessageType {WebSocketMessageType::Error};
switch (messageKind)
{

View File

@ -16,8 +16,8 @@
#include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketMessage.h"
#include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketSendData.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketTransport.h"
#include <atomic>
#include <condition_variable>
@ -128,7 +128,8 @@ namespace ix
// Server
WebSocketInitResult connectToSocket(std::unique_ptr<Socket>,
int timeoutSecs,
bool enablePerMessageDeflate);
bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
WebSocketTransport _ws;

View File

@ -240,13 +240,26 @@ namespace ix
}
WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate)
bool enablePerMessageDeflate,
HttpRequestPtr request)
{
_requestInitCancellation = false;
auto isCancellationRequested =
makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
std::string method;
std::string uri;
std::string httpVersion;
if (request)
{
method = request->method;
uri = request->uri;
httpVersion = request->version;
}
else
{
// Read first line
auto lineResult = _socket->readLine(isCancellationRequested);
auto lineValid = lineResult.first;
@ -259,9 +272,10 @@ namespace ix
// Validate request line (GET /foo HTTP/1.1\r\n)
auto requestLine = Http::parseRequestLine(line);
auto method = std::get<0>(requestLine);
auto uri = std::get<1>(requestLine);
auto httpVersion = std::get<2>(requestLine);
method = std::get<0>(requestLine);
uri = std::get<1>(requestLine);
httpVersion = std::get<2>(requestLine);
}
if (method != "GET")
{
@ -274,15 +288,23 @@ namespace ix
"Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
}
WebSocketHttpHeaders headers;
if (request)
{
headers = request->headers;
}
else
{
// Retrieve and validate HTTP headers
auto result = parseHttpHeaders(_socket, isCancellationRequested);
auto headersValid = result.first;
auto headers = result.second;
headers = result.second;
if (!headersValid)
{
return sendErrorResponse(400, "Error parsing HTTP headers");
}
}
if (headers.find("sec-websocket-key") == headers.end())
{

View File

@ -7,6 +7,7 @@
#pragma once
#include "IXCancellationRequest.h"
#include "IXHttp.h"
#include "IXSocket.h"
#include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketInitResult.h"
@ -35,7 +36,9 @@ namespace ix
int port,
int timeoutSecs);
WebSocketInitResult serverHandshake(int timeoutSecs, bool enablePerMessageDeflate);
WebSocketInitResult serverHandshake(int timeoutSecs,
bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
private:
std::string genRandomString(const int len);

View File

@ -78,6 +78,15 @@ namespace ix
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{
handleUpgrade(std::move(socket), connectionState);
connectionState->setTerminated();
}
void WebSocketServer::handleUpgrade(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState,
HttpRequestPtr request)
{
setThreadName("Srv:ws:" + connectionState->getId());
@ -89,7 +98,7 @@ namespace ix
if (!webSocket->isOnMessageCallbackRegistered())
{
logError("WebSocketServer Application developer error: Server callback improperly "
"registerered.");
"registered.");
logError("Missing call to setOnMessageCallback inside setOnConnectionCallback.");
connectionState->setTerminated();
return;
@ -99,9 +108,8 @@ namespace ix
{
WebSocket* webSocketRawPtr = webSocket.get();
webSocket->setOnMessageCallback(
[this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) {
_onClientMessageCallback(connectionState, *webSocketRawPtr, msg);
});
[this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg)
{ _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); });
}
else
{
@ -130,7 +138,7 @@ namespace ix
}
auto status = webSocket->connectToSocket(
std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate);
std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate, request);
if (status.success)
{
// Process incoming messages and execute callbacks
@ -155,8 +163,6 @@ namespace ix
logError("Cannot delete client");
}
}
connectionState->setTerminated();
}
std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients()
@ -176,9 +182,11 @@ namespace ix
//
void WebSocketServer::makeBroadcastServer()
{
setOnClientMessageCallback([this](std::shared_ptr<ConnectionState> connectionState,
setOnClientMessageCallback(
[this](std::shared_ptr<ConnectionState> connectionState,
WebSocket& webSocket,
const WebSocketMessagePtr& msg) {
const WebSocketMessagePtr& msg)
{
auto remoteIp = connectionState->getRemoteIp();
if (msg->type == ix::WebSocketMessageType::Message)
{

View File

@ -55,6 +55,7 @@ namespace ix
int getHandshakeTimeoutSecs();
bool isPongEnabled();
bool isPerMessageDeflateEnabled();
private:
// Member variables
int _handshakeTimeoutSecs;
@ -73,5 +74,10 @@ namespace ix
virtual void handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState);
virtual size_t getConnectedClientsCount() final;
protected:
void handleUpgrade(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState,
HttpRequestPtr request = nullptr);
};
} // namespace ix

View File

@ -170,7 +170,8 @@ namespace ix
// Server
WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs,
bool enablePerMessageDeflate)
bool enablePerMessageDeflate,
HttpRequestPtr request)
{
std::lock_guard<std::mutex> lock(_socketMutex);
@ -187,7 +188,8 @@ namespace ix
_perMessageDeflateOptions,
_enablePerMessageDeflate);
auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate);
auto result =
webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate, request);
if (result.success)
{
setReadyState(ReadyState::OPEN);

View File

@ -18,8 +18,8 @@
#include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketPerMessageDeflate.h"
#include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXWebSocketSendInfo.h"
#include "IXWebSocketSendData.h"
#include "IXWebSocketSendInfo.h"
#include <atomic>
#include <functional>
#include <list>
@ -86,7 +86,8 @@ namespace ix
// Server
WebSocketInitResult connectToSocket(std::unique_ptr<Socket> socket,
int timeoutSecs,
bool enablePerMessageDeflate);
bool enablePerMessageDeflate,
HttpRequestPtr request = nullptr);
PollResult poll();
WebSocketSendInfo sendBinary(const IXWebSocketSendData& message,