SocketServer::handleConnection takes an std::shared_ptr<Socket> instead of a file descriptor

This commit is contained in:
Benjamin Sergeant 2019-09-30 21:48:55 -07:00
parent 562d7484e4
commit 1ed39677ce
14 changed files with 47 additions and 49 deletions

View File

@ -8,7 +8,6 @@
#include <ixwebsocket/IXNetSystem.h> #include <ixwebsocket/IXNetSystem.h>
#include <ixwebsocket/IXSocketConnect.h> #include <ixwebsocket/IXSocketConnect.h>
#include <ixwebsocket/IXSocketFactory.h>
#include <ixwebsocket/IXSocket.h> #include <ixwebsocket/IXSocket.h>
#include <ixwebsocket/IXCancellationRequest.h> #include <ixwebsocket/IXCancellationRequest.h>
#include <fstream> #include <fstream>
@ -45,16 +44,11 @@ namespace ix
SocketServer::stop(); SocketServer::stop();
} }
void RedisServer::handleConnection(int fd, std::shared_ptr<ConnectionState> connectionState) void RedisServer::handleConnection(std::shared_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{ {
_connectedClientsCount++; _connectedClientsCount++;
std::string errorMsg;
auto socket = createSocket(fd, errorMsg);
// Set the socket to non blocking mode + other tweaks
SocketConnect::configure(fd);
while (!_stopHandlingConnections) while (!_stopHandlingConnections)
{ {
std::vector<std::string> tokens; std::vector<std::string> tokens;
@ -105,7 +99,6 @@ namespace ix
logInfo("Connection closed for connection id " + connectionState->getId()); logInfo("Connection closed for connection id " + connectionState->getId());
connectionState->setTerminated(); connectionState->setTerminated();
Socket::closeSocket(fd);
_connectedClientsCount--; _connectedClientsCount--;
} }

View File

@ -42,7 +42,7 @@ namespace ix
std::atomic<bool> _stopHandlingConnections; std::atomic<bool> _stopHandlingConnections;
// Methods // Methods
virtual void handleConnection(int fd, virtual void handleConnection(std::shared_ptr<Socket>,
std::shared_ptr<ConnectionState> connectionState) final; std::shared_ptr<ConnectionState> connectionState) final;
virtual size_t getConnectedClientsCount() final; virtual size_t getConnectedClientsCount() final;

View File

@ -8,7 +8,6 @@
#include "IXNetSystem.h" #include "IXNetSystem.h"
#include "IXSocketConnect.h" #include "IXSocketConnect.h"
#include "IXSocketFactory.h"
#include "IXUserAgent.h" #include "IXUserAgent.h"
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
@ -70,16 +69,11 @@ namespace ix
_onConnectionCallback = callback; _onConnectionCallback = callback;
} }
void HttpServer::handleConnection(int fd, std::shared_ptr<ConnectionState> connectionState) void HttpServer::handleConnection(std::shared_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{ {
_connectedClientsCount++; _connectedClientsCount++;
std::string errorMsg;
auto socket = createSocket(fd, errorMsg);
// Set the socket to non blocking mode + other tweaks
SocketConnect::configure(fd);
auto ret = Http::parseRequest(socket); auto ret = Http::parseRequest(socket);
// FIXME: handle errors in parseRequest // FIXME: handle errors in parseRequest
@ -92,7 +86,6 @@ namespace ix
} }
} }
connectionState->setTerminated(); connectionState->setTerminated();
Socket::closeSocket(fd);
_connectedClientsCount--; _connectedClientsCount--;
} }

View File

@ -42,7 +42,7 @@ namespace ix
std::atomic<int> _connectedClientsCount; std::atomic<int> _connectedClientsCount;
// Methods // Methods
virtual void handleConnection(int fd, virtual void handleConnection(std::shared_ptr<Socket>,
std::shared_ptr<ConnectionState> connectionState) final; std::shared_ptr<ConnectionState> connectionState) final;
virtual size_t getConnectedClientsCount() final; virtual size_t getConnectedClientsCount() final;

View File

@ -9,6 +9,7 @@
#include "IXNetSystem.h" #include "IXNetSystem.h"
#include "IXSocket.h" #include "IXSocket.h"
#include "IXSocketConnect.h" #include "IXSocketConnect.h"
#include "IXSocketFactory.h"
#include <assert.h> #include <assert.h>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
@ -267,11 +268,25 @@ namespace ix
if (_stop) return; if (_stop) return;
// create socket
std::string errorMsg;
auto socket = createSocket(clientFd, errorMsg);
if (socket == nullptr)
{
logError("SocketServer::run() cannot create socket: " + errorMsg);
Socket::closeSocket(clientFd);
continue;
}
// Set the socket to non blocking mode + other tweaks
SocketConnect::configure(clientFd);
// Launch the handleConnection work asynchronously in its own thread. // Launch the handleConnection work asynchronously in its own thread.
std::lock_guard<std::mutex> lock(_connectionsThreadsMutex); std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
_connectionsThreads.push_back(std::make_pair( _connectionsThreads.push_back(std::make_pair(
connectionState, connectionState,
std::thread(&SocketServer::handleConnection, this, clientFd, connectionState))); std::thread(&SocketServer::handleConnection, this, socket, connectionState)));
} }
} }

View File

@ -21,6 +21,8 @@
namespace ix namespace ix
{ {
class Socket;
class SocketServer class SocketServer
{ {
public: public:
@ -96,7 +98,8 @@ namespace ix
// the factory to create ConnectionState objects // the factory to create ConnectionState objects
ConnectionStateFactory _connectionStateFactory; ConnectionStateFactory _connectionStateFactory;
virtual void handleConnection(int fd, std::shared_ptr<ConnectionState> connectionState) = 0; virtual void handleConnection(std::shared_ptr<Socket>,
std::shared_ptr<ConnectionState> connectionState) = 0;
virtual size_t getConnectedClientsCount() = 0; virtual size_t getConnectedClientsCount() = 0;
// Returns true if all connection threads are joined // Returns true if all connection threads are joined

View File

@ -201,7 +201,8 @@ namespace ix
return status; return status;
} }
WebSocketInitResult WebSocket::connectToSocket(int fd, int timeoutSecs) WebSocketInitResult WebSocket::connectToSocket(std::shared_ptr<Socket> socket,
int timeoutSecs)
{ {
{ {
std::lock_guard<std::mutex> lock(_configMutex); std::lock_guard<std::mutex> lock(_configMutex);
@ -212,7 +213,7 @@ namespace ix
_pingTimeoutSecs); _pingTimeoutSecs);
} }
WebSocketInitResult status = _ws.connectToSocket(fd, timeoutSecs); WebSocketInitResult status = _ws.connectToSocket(socket, timeoutSecs);
if (!status.success) if (!status.success)
{ {
return status; return status;

View File

@ -113,7 +113,8 @@ namespace ix
static void invokeTrafficTrackerCallback(size_t size, bool incoming); static void invokeTrafficTrackerCallback(size_t size, bool incoming);
// Server // Server
WebSocketInitResult connectToSocket(int fd, int timeoutSecs); WebSocketInitResult connectToSocket(std::shared_ptr<Socket>,
int timeoutSecs);
WebSocketTransport _ws; WebSocketTransport _ws;

View File

@ -239,18 +239,13 @@ namespace ix
return WebSocketInitResult(true, status, "", headers, path); return WebSocketInitResult(true, status, "", headers, path);
} }
WebSocketInitResult WebSocketHandshake::serverHandshake(int fd, int timeoutSecs) WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs)
{ {
_requestInitCancellation = false; _requestInitCancellation = false;
// Set the socket to non blocking mode + other tweaks
SocketConnect::configure(fd);
auto isCancellationRequested = auto isCancellationRequested =
makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
std::string remote = std::string("remote fd ") + std::to_string(fd);
// Read first line // Read first line
auto lineResult = _socket->readLine(isCancellationRequested); auto lineResult = _socket->readLine(isCancellationRequested);
auto lineValid = lineResult.first; auto lineValid = lineResult.first;
@ -358,7 +353,7 @@ namespace ix
if (!_socket->writeBytes(ss.str(), isCancellationRequested)) if (!_socket->writeBytes(ss.str(), isCancellationRequested))
{ {
return WebSocketInitResult( return WebSocketInitResult(
false, 0, std::string("Failed sending response to ") + remote); false, 0, std::string("Failed sending response to remote end"));
} }
return WebSocketInitResult(true, 200, "", headers, uri); return WebSocketInitResult(true, 200, "", headers, uri);

View File

@ -56,7 +56,7 @@ namespace ix
int port, int port,
int timeoutSecs); int timeoutSecs);
WebSocketInitResult serverHandshake(int fd, int timeoutSecs); WebSocketInitResult serverHandshake(int timeoutSecs);
private: private:
std::string genRandomString(const int len); std::string genRandomString(const int len);

View File

@ -63,7 +63,8 @@ namespace ix
_onConnectionCallback = callback; _onConnectionCallback = callback;
} }
void WebSocketServer::handleConnection(int fd, std::shared_ptr<ConnectionState> connectionState) void WebSocketServer::handleConnection(std::shared_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState)
{ {
auto webSocket = std::make_shared<WebSocket>(); auto webSocket = std::make_shared<WebSocket>();
_onConnectionCallback(webSocket, connectionState); _onConnectionCallback(webSocket, connectionState);
@ -81,7 +82,7 @@ namespace ix
_clients.insert(webSocket); _clients.insert(webSocket);
} }
auto status = webSocket->connectToSocket(fd, _handshakeTimeoutSecs); auto status = webSocket->connectToSocket(socket, _handshakeTimeoutSecs);
if (status.success) if (status.success)
{ {
// Process incoming messages and execute callbacks // Process incoming messages and execute callbacks
@ -107,8 +108,6 @@ namespace ix
logInfo("WebSocketServer::handleConnection() done"); logInfo("WebSocketServer::handleConnection() done");
connectionState->setTerminated(); connectionState->setTerminated();
Socket::closeSocket(fd);
} }
std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients() std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients()

View File

@ -55,7 +55,7 @@ namespace ix
const static bool kDefaultEnablePong; const static bool kDefaultEnablePong;
// Methods // Methods
virtual void handleConnection(int fd, virtual void handleConnection(std::shared_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState) final; std::shared_ptr<ConnectionState> connectionState) final;
virtual size_t getConnectedClientsCount() final; virtual size_t getConnectedClientsCount() final;
}; };

View File

@ -171,20 +171,15 @@ namespace ix
} }
// Server // Server
WebSocketInitResult WebSocketTransport::connectToSocket(int fd, int timeoutSecs) WebSocketInitResult WebSocketTransport::connectToSocket(std::shared_ptr<Socket> socket,
int timeoutSecs)
{ {
std::lock_guard<std::mutex> lock(_socketMutex); std::lock_guard<std::mutex> lock(_socketMutex);
// Server should not mask the data it sends to the client // Server should not mask the data it sends to the client
_useMask = false; _useMask = false;
std::string errorMsg; _socket = socket;
_socket = createSocket(fd, errorMsg);
if (!_socket)
{
return WebSocketInitResult(false, 0, errorMsg);
}
WebSocketHandshake webSocketHandshake(_requestInitCancellation, WebSocketHandshake webSocketHandshake(_requestInitCancellation,
_socket, _socket,
@ -192,7 +187,7 @@ namespace ix
_perMessageDeflateOptions, _perMessageDeflateOptions,
_enablePerMessageDeflate); _enablePerMessageDeflate);
auto result = webSocketHandshake.serverHandshake(fd, timeoutSecs); auto result = webSocketHandshake.serverHandshake(timeoutSecs);
if (result.success) if (result.success)
{ {
setReadyState(ReadyState::OPEN); setReadyState(ReadyState::OPEN);

View File

@ -77,11 +77,14 @@ namespace ix
int pingIntervalSecs, int pingIntervalSecs,
int pingTimeoutSecs); int pingTimeoutSecs);
WebSocketInitResult connectToUrl( // Client // Client
WebSocketInitResult connectToUrl(
const std::string& url, const std::string& url,
const WebSocketHttpHeaders& headers, const WebSocketHttpHeaders& headers,
int timeoutSecs); int timeoutSecs);
WebSocketInitResult connectToSocket(int fd, // Server
// Server
WebSocketInitResult connectToSocket(std::shared_ptr<Socket> socket,
int timeoutSecs); int timeoutSecs);
PollResult poll(); PollResult poll();