New connection state for server code + fix OpenSSL double init bug
This commit is contained in:
parent
0999073435
commit
788c92487c
@ -12,7 +12,7 @@ set (CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
# -Wshorten-64-to-32 does not work with clang
|
||||
if (NOT WIN32)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Wshorten-64-to-32")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic")
|
||||
endif()
|
||||
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
@ -39,6 +39,7 @@ set( IXWEBSOCKET_SOURCES
|
||||
ixwebsocket/IXSelectInterrupt.cpp
|
||||
ixwebsocket/IXSelectInterruptPipe.cpp
|
||||
ixwebsocket/IXSelectInterruptFactory.cpp
|
||||
ixwebsocket/IXConnectionState.cpp
|
||||
)
|
||||
|
||||
set( IXWEBSOCKET_HEADERS
|
||||
@ -66,6 +67,7 @@ set( IXWEBSOCKET_HEADERS
|
||||
ixwebsocket/IXSelectInterrupt.h
|
||||
ixwebsocket/IXSelectInterruptPipe.h
|
||||
ixwebsocket/IXSelectInterruptFactory.h
|
||||
ixwebsocket/IXConnectionState.h
|
||||
)
|
||||
|
||||
# Platform specific code
|
||||
|
37
ixwebsocket/IXConnectionState.cpp
Normal file
37
ixwebsocket/IXConnectionState.cpp
Normal file
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* IXConnectionState.cpp
|
||||
* Author: Benjamin Sergeant
|
||||
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
|
||||
*/
|
||||
|
||||
#include "IXConnectionState.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace ix
|
||||
{
|
||||
std::atomic<uint64_t> ConnectionState::_globalId(0);
|
||||
|
||||
ConnectionState::ConnectionState()
|
||||
{
|
||||
computeId();
|
||||
}
|
||||
|
||||
void ConnectionState::computeId()
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << _globalId++;
|
||||
_id = ss.str();
|
||||
}
|
||||
|
||||
const std::string& ConnectionState::getId() const
|
||||
{
|
||||
return _id;
|
||||
}
|
||||
|
||||
std::shared_ptr<ConnectionState> ConnectionState::createConnectionState()
|
||||
{
|
||||
return std::make_shared<ConnectionState>();
|
||||
}
|
||||
}
|
||||
|
33
ixwebsocket/IXConnectionState.h
Normal file
33
ixwebsocket/IXConnectionState.h
Normal file
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* IXConnectionState.h
|
||||
* Author: Benjamin Sergeant
|
||||
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
|
||||
namespace ix
|
||||
{
|
||||
class ConnectionState {
|
||||
public:
|
||||
ConnectionState();
|
||||
virtual ~ConnectionState() = default;
|
||||
|
||||
virtual void computeId();
|
||||
virtual const std::string& getId() const;
|
||||
|
||||
static std::shared_ptr<ConnectionState> createConnectionState();
|
||||
|
||||
protected:
|
||||
std::string _id;
|
||||
|
||||
static std::atomic<uint64_t> _globalId;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@
|
||||
namespace ix
|
||||
{
|
||||
std::atomic<bool> SocketOpenSSL::_openSSLInitializationSuccessful(false);
|
||||
std::once_flag SocketOpenSSL::_openSSLInitFlag;
|
||||
|
||||
SocketOpenSSL::SocketOpenSSL(int fd) : Socket(fd),
|
||||
_ssl_connection(nullptr),
|
||||
|
@ -50,7 +50,7 @@ namespace ix
|
||||
const SSL_METHOD* _ssl_method;
|
||||
mutable std::mutex _mutex; // OpenSSL routines are not thread-safe
|
||||
|
||||
std::once_flag _openSSLInitFlag;
|
||||
static std::once_flag _openSSLInitFlag;
|
||||
static std::atomic<bool> _openSSLInitializationSuccessful;
|
||||
};
|
||||
|
||||
|
@ -29,7 +29,8 @@ namespace ix
|
||||
_host(host),
|
||||
_backlog(backlog),
|
||||
_maxConnections(maxConnections),
|
||||
_stop(false)
|
||||
_stop(false),
|
||||
_connectionStateFactory(&ConnectionState::createConnectionState)
|
||||
{
|
||||
|
||||
}
|
||||
@ -145,6 +146,12 @@ namespace ix
|
||||
::close(_serverFd);
|
||||
}
|
||||
|
||||
void SocketServer::setConnectionStateFactory(
|
||||
const ConnectionStateFactory& connectionStateFactory)
|
||||
{
|
||||
_connectionStateFactory = connectionStateFactory;
|
||||
}
|
||||
|
||||
void SocketServer::run()
|
||||
{
|
||||
// Set the socket to non blocking mode, so that accept calls are not blocking
|
||||
@ -214,6 +221,12 @@ namespace ix
|
||||
continue;
|
||||
}
|
||||
|
||||
std::shared_ptr<ConnectionState> connectionState;
|
||||
if (_connectionStateFactory)
|
||||
{
|
||||
connectionState = _connectionStateFactory();
|
||||
}
|
||||
|
||||
// Launch the handleConnection work asynchronously in its own thread.
|
||||
//
|
||||
// the destructor of a future returned by std::async blocks,
|
||||
@ -221,7 +234,8 @@ namespace ix
|
||||
f = std::async(std::launch::async,
|
||||
&SocketServer::handleConnection,
|
||||
this,
|
||||
clientFd);
|
||||
clientFd,
|
||||
connectionState);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "IXConnectionState.h"
|
||||
|
||||
#include <utility> // pair
|
||||
#include <string>
|
||||
#include <set>
|
||||
@ -20,6 +22,8 @@ namespace ix
|
||||
{
|
||||
class SocketServer {
|
||||
public:
|
||||
using ConnectionStateFactory = std::function<std::shared_ptr<ConnectionState>()>;
|
||||
|
||||
SocketServer(int port = SocketServer::kDefaultPort,
|
||||
const std::string& host = SocketServer::kDefaultHost,
|
||||
int backlog = SocketServer::kDefaultTcpBacklog,
|
||||
@ -27,6 +31,8 @@ namespace ix
|
||||
virtual ~SocketServer();
|
||||
virtual void stop();
|
||||
|
||||
void setConnectionStateFactory(const ConnectionStateFactory& connectionStateFactory);
|
||||
|
||||
const static int kDefaultPort;
|
||||
const static std::string kDefaultHost;
|
||||
const static int kDefaultTcpBacklog;
|
||||
@ -60,9 +66,13 @@ namespace ix
|
||||
std::condition_variable _conditionVariable;
|
||||
std::mutex _conditionVariableMutex;
|
||||
|
||||
//
|
||||
ConnectionStateFactory _connectionStateFactory;
|
||||
|
||||
// Methods
|
||||
void run();
|
||||
virtual void handleConnection(int fd) = 0;
|
||||
virtual void handleConnection(int fd,
|
||||
std::shared_ptr<ConnectionState> connectionState) = 0;
|
||||
virtual size_t getConnectedClientsCount() = 0;
|
||||
};
|
||||
}
|
||||
|
@ -49,10 +49,12 @@ namespace ix
|
||||
_onConnectionCallback = callback;
|
||||
}
|
||||
|
||||
void WebSocketServer::handleConnection(int fd)
|
||||
void WebSocketServer::handleConnection(
|
||||
int fd,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
auto webSocket = std::make_shared<WebSocket>();
|
||||
_onConnectionCallback(webSocket);
|
||||
_onConnectionCallback(webSocket, connectionState);
|
||||
|
||||
webSocket->disableAutomaticReconnection();
|
||||
|
||||
|
@ -20,7 +20,8 @@
|
||||
|
||||
namespace ix
|
||||
{
|
||||
using OnConnectionCallback = std::function<void(std::shared_ptr<WebSocket>)>;
|
||||
using OnConnectionCallback = std::function<void(std::shared_ptr<WebSocket>,
|
||||
std::shared_ptr<ConnectionState>)>;
|
||||
|
||||
class WebSocketServer : public SocketServer {
|
||||
public:
|
||||
@ -49,7 +50,8 @@ namespace ix
|
||||
const static int kDefaultHandShakeTimeoutSecs;
|
||||
|
||||
// Methods
|
||||
virtual void handleConnection(int fd) final;
|
||||
virtual void handleConnection(int fd,
|
||||
std::shared_ptr<ConnectionState> connectionState) final;
|
||||
virtual size_t getConnectedClientsCount() final;
|
||||
};
|
||||
}
|
||||
|
@ -128,10 +128,11 @@ namespace
|
||||
{
|
||||
// A dev/null server
|
||||
server.setOnConnectionCallback(
|
||||
[&server, &receivedPingMessages](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[&server, &receivedPingMessages](std::shared_ptr<ix::WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket, &server, &receivedPingMessages](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState, &server, &receivedPingMessages](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -141,6 +142,7 @@ namespace
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
Logger() << "New server connection";
|
||||
Logger() << "id: " << connectionState->getId();
|
||||
Logger() << "Uri: " << openInfo.uri;
|
||||
Logger() << "Headers:";
|
||||
for (auto it : openInfo.headers)
|
||||
|
@ -18,13 +18,32 @@ using namespace ix;
|
||||
|
||||
namespace ix
|
||||
{
|
||||
bool startServer(ix::WebSocketServer& server)
|
||||
// Test that we can override the connectionState impl to provide our own
|
||||
class ConnectionStateCustom : public ConnectionState
|
||||
{
|
||||
void computeId()
|
||||
{
|
||||
// a very boring invariant id that we can test against in the unittest
|
||||
_id = "foobarConnectionId";
|
||||
}
|
||||
};
|
||||
|
||||
bool startServer(ix::WebSocketServer& server,
|
||||
std::string& connectionId)
|
||||
{
|
||||
auto factory = []() -> std::shared_ptr<ConnectionState>
|
||||
{
|
||||
return std::make_shared<ConnectionStateCustom>();
|
||||
};
|
||||
server.setConnectionStateFactory(factory);
|
||||
|
||||
server.setOnConnectionCallback(
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[&server, &connectionId](std::shared_ptr<ix::WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket, &server](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState,
|
||||
&connectionId, &server](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -34,12 +53,15 @@ namespace ix
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
Logger() << "New connection";
|
||||
Logger() << "id: " << connectionState->getId();
|
||||
Logger() << "Uri: " << openInfo.uri;
|
||||
Logger() << "Headers:";
|
||||
for (auto it : openInfo.headers)
|
||||
{
|
||||
Logger() << it.first << ": " << it.second;
|
||||
}
|
||||
|
||||
connectionId = connectionState->getId();
|
||||
}
|
||||
else if (messageType == ix::WebSocket_MessageType_Close)
|
||||
{
|
||||
@ -78,7 +100,8 @@ TEST_CASE("Websocket_server", "[websocket_server]")
|
||||
{
|
||||
int port = getFreePort();
|
||||
ix::WebSocketServer server(port);
|
||||
REQUIRE(startServer(server));
|
||||
std::string connectionId;
|
||||
REQUIRE(startServer(server, connectionId));
|
||||
|
||||
std::string errMsg;
|
||||
bool tls = false;
|
||||
@ -111,7 +134,8 @@ TEST_CASE("Websocket_server", "[websocket_server]")
|
||||
{
|
||||
int port = getFreePort();
|
||||
ix::WebSocketServer server(port);
|
||||
REQUIRE(startServer(server));
|
||||
std::string connectionId;
|
||||
REQUIRE(startServer(server, connectionId));
|
||||
|
||||
std::string errMsg;
|
||||
bool tls = false;
|
||||
@ -147,7 +171,8 @@ TEST_CASE("Websocket_server", "[websocket_server]")
|
||||
{
|
||||
int port = getFreePort();
|
||||
ix::WebSocketServer server(port);
|
||||
REQUIRE(startServer(server));
|
||||
std::string connectionId;
|
||||
REQUIRE(startServer(server, connectionId));
|
||||
|
||||
std::string errMsg;
|
||||
bool tls = false;
|
||||
@ -178,6 +203,8 @@ TEST_CASE("Websocket_server", "[websocket_server]")
|
||||
// Give us 500ms for the server to notice that clients went away
|
||||
ix::msleep(500);
|
||||
|
||||
REQUIRE(connectionId == "foobarConnectionId");
|
||||
|
||||
server.stop();
|
||||
REQUIRE(server.getClients().size() == 0);
|
||||
}
|
||||
|
@ -217,10 +217,11 @@ namespace
|
||||
bool startServer(ix::WebSocketServer& server)
|
||||
{
|
||||
server.setOnConnectionCallback(
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket, &server](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState, &server](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -230,6 +231,7 @@ namespace
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
Logger() << "New connection";
|
||||
Logger() << "id: " << connectionState->getId();
|
||||
Logger() << "Uri: " << openInfo.uri;
|
||||
Logger() << "Headers:";
|
||||
for (auto it : openInfo.headers)
|
||||
|
@ -114,7 +114,7 @@ namespace ix
|
||||
return false;
|
||||
}
|
||||
|
||||
// The first line of the response describe the return type,
|
||||
// The first line of the response describe the return type,
|
||||
// => *3 (an array of 3 elements)
|
||||
auto lineResult = _socket->readLine(nullptr);
|
||||
auto lineValid = lineResult.first;
|
||||
|
@ -17,10 +17,11 @@ namespace ix
|
||||
ix::WebSocketServer server(port, hostname);
|
||||
|
||||
server.setOnConnectionCallback(
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[&server](std::shared_ptr<WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket, &server](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState, &server](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -30,6 +31,7 @@ namespace ix
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
std::cerr << "New connection" << std::endl;
|
||||
std::cerr << "id: " << connectionState->getId() << std::endl;
|
||||
std::cerr << "Uri: " << openInfo.uri << std::endl;
|
||||
std::cerr << "Headers:" << std::endl;
|
||||
for (auto it : openInfo.headers)
|
||||
|
@ -17,10 +17,11 @@ namespace ix
|
||||
ix::WebSocketServer server(port, hostname);
|
||||
|
||||
server.setOnConnectionCallback(
|
||||
[](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[](std::shared_ptr<ix::WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -30,6 +31,7 @@ namespace ix
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
std::cerr << "New connection" << std::endl;
|
||||
std::cerr << "id: " << connectionState->getId() << std::endl;
|
||||
std::cerr << "Uri: " << openInfo.uri << std::endl;
|
||||
std::cerr << "Headers:" << std::endl;
|
||||
for (auto it : openInfo.headers)
|
||||
|
@ -17,10 +17,11 @@ namespace ix
|
||||
ix::WebSocketServer server(port, hostname);
|
||||
|
||||
server.setOnConnectionCallback(
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket)
|
||||
[&server](std::shared_ptr<ix::WebSocket> webSocket,
|
||||
std::shared_ptr<ConnectionState> connectionState)
|
||||
{
|
||||
webSocket->setOnMessageCallback(
|
||||
[webSocket, &server](ix::WebSocketMessageType messageType,
|
||||
[webSocket, connectionState, &server](ix::WebSocketMessageType messageType,
|
||||
const std::string& str,
|
||||
size_t wireSize,
|
||||
const ix::WebSocketErrorInfo& error,
|
||||
@ -30,6 +31,7 @@ namespace ix
|
||||
if (messageType == ix::WebSocket_MessageType_Open)
|
||||
{
|
||||
std::cerr << "New connection" << std::endl;
|
||||
std::cerr << "id: " << connectionState->getId() << std::endl;
|
||||
std::cerr << "Uri: " << openInfo.uri << std::endl;
|
||||
std::cerr << "Headers:" << std::endl;
|
||||
for (auto it : openInfo.headers)
|
||||
|
Loading…
Reference in New Issue
Block a user