diff --git a/CMakeLists.txt b/CMakeLists.txt index 864818f6..31f20a40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ set( IXWEBSOCKET_SOURCES ixwebsocket/IXSelectInterrupt.cpp ixwebsocket/IXSelectInterruptPipe.cpp ixwebsocket/IXSelectInterruptFactory.cpp + ixwebsocket/IXConnectionState.cpp ) set( IXWEBSOCKET_HEADERS @@ -66,6 +67,7 @@ set( IXWEBSOCKET_HEADERS ixwebsocket/IXSelectInterrupt.h ixwebsocket/IXSelectInterruptPipe.h ixwebsocket/IXSelectInterruptFactory.h + ixwebsocket/IXConnectionState.h ) # Platform specific code diff --git a/README.md b/README.md index 09d62013..519525af 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ communication channels over a single TCP connection. *IXWebSocket* is a C++ libr * macOS * iOS * Linux -* Android +* Android ## Examples @@ -63,10 +63,11 @@ Here is what the server API looks like. Note that server support is very recent ix::WebSocketServer server(port); server.setOnConnectionCallback( - [&server](std::shared_ptr webSocket) + [&server](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server](ix::WebSocketMessageType messageType, + [webSocket, connectionState, &server](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -77,6 +78,12 @@ server.setOnConnectionCallback( { std::cerr << "New connection" << std::endl; + // A connection state object is available, and has a default id + // You can subclass ConnectionState and pass an alternate factory + // to override it. It is useful if you want to store custom + // attributes per connection (authenticated bool flag, attributes, etc...) + std::cerr << "id: " << connectionState->getId() << std::endl; + // The uri the client did connect to. std::cerr << "Uri: " << openInfo.uri << std::endl; @@ -223,13 +230,13 @@ Here is a simplistic diagram which explains how the code is structured in term o +-----------------------+ --- Public | | Start the receiving Background thread. Auto reconnection. Simple websocket Ping. | IXWebSocket | Interface used by C++ test clients. No IX dependencies. -| | +| | +-----------------------+ | | | IXWebSocketServer | Run a server and give each connections its own WebSocket object. | | Each connection is handled in a new OS thread. | | -+-----------------------+ --- Private ++-----------------------+ --- Private | | | IXWebSocketTransport | Low level websocket code, framing, managing raw socket. Adapted from easywsclient. | | diff --git a/ixwebsocket/IXConnectionState.cpp b/ixwebsocket/IXConnectionState.cpp new file mode 100644 index 00000000..18a575e0 --- /dev/null +++ b/ixwebsocket/IXConnectionState.cpp @@ -0,0 +1,37 @@ +/* + * IXConnectionState.cpp + * Author: Benjamin Sergeant + * Copyright (c) 2019 Machine Zone, Inc. All rights reserved. + */ + +#include "IXConnectionState.h" + +#include + +namespace ix +{ + std::atomic ConnectionState::_globalId(0); + + ConnectionState::ConnectionState() + { + computeId(); + } + + void ConnectionState::computeId() + { + std::stringstream ss; + ss << _globalId++; + _id = ss.str(); + } + + const std::string& ConnectionState::getId() const + { + return _id; + } + + std::shared_ptr ConnectionState::createConnectionState() + { + return std::make_shared(); + } +} + diff --git a/ixwebsocket/IXConnectionState.h b/ixwebsocket/IXConnectionState.h new file mode 100644 index 00000000..0c5d9920 --- /dev/null +++ b/ixwebsocket/IXConnectionState.h @@ -0,0 +1,33 @@ +/* + * IXConnectionState.h + * Author: Benjamin Sergeant + * Copyright (c) 2019 Machine Zone, Inc. All rights reserved. + */ + +#pragma once + +#include +#include +#include +#include + +namespace ix +{ + class ConnectionState { + public: + ConnectionState(); + virtual ~ConnectionState() = default; + + virtual void computeId(); + virtual const std::string& getId() const; + + static std::shared_ptr createConnectionState(); + + protected: + std::string _id; + + static std::atomic _globalId; + }; +} + + diff --git a/ixwebsocket/IXSocketOpenSSL.cpp b/ixwebsocket/IXSocketOpenSSL.cpp index bc3d9774..3e6b14fc 100644 --- a/ixwebsocket/IXSocketOpenSSL.cpp +++ b/ixwebsocket/IXSocketOpenSSL.cpp @@ -21,6 +21,7 @@ namespace ix { std::atomic SocketOpenSSL::_openSSLInitializationSuccessful(false); + std::once_flag SocketOpenSSL::_openSSLInitFlag; SocketOpenSSL::SocketOpenSSL(int fd) : Socket(fd), _ssl_connection(nullptr), diff --git a/ixwebsocket/IXSocketOpenSSL.h b/ixwebsocket/IXSocketOpenSSL.h index ac0d18a6..9fa55a5a 100644 --- a/ixwebsocket/IXSocketOpenSSL.h +++ b/ixwebsocket/IXSocketOpenSSL.h @@ -50,7 +50,7 @@ namespace ix const SSL_METHOD* _ssl_method; mutable std::mutex _mutex; // OpenSSL routines are not thread-safe - std::once_flag _openSSLInitFlag; + static std::once_flag _openSSLInitFlag; static std::atomic _openSSLInitializationSuccessful; }; diff --git a/ixwebsocket/IXSocketServer.cpp b/ixwebsocket/IXSocketServer.cpp index 3f8beef4..c77c76be 100644 --- a/ixwebsocket/IXSocketServer.cpp +++ b/ixwebsocket/IXSocketServer.cpp @@ -29,7 +29,8 @@ namespace ix _host(host), _backlog(backlog), _maxConnections(maxConnections), - _stop(false) + _stop(false), + _connectionStateFactory(&ConnectionState::createConnectionState) { } @@ -145,6 +146,12 @@ namespace ix ::close(_serverFd); } + void SocketServer::setConnectionStateFactory( + const ConnectionStateFactory& connectionStateFactory) + { + _connectionStateFactory = connectionStateFactory; + } + void SocketServer::run() { // Set the socket to non blocking mode, so that accept calls are not blocking @@ -214,6 +221,12 @@ namespace ix continue; } + std::shared_ptr connectionState; + if (_connectionStateFactory) + { + connectionState = _connectionStateFactory(); + } + // Launch the handleConnection work asynchronously in its own thread. // // the destructor of a future returned by std::async blocks, @@ -221,7 +234,8 @@ namespace ix f = std::async(std::launch::async, &SocketServer::handleConnection, this, - clientFd); + clientFd, + connectionState); } } } diff --git a/ixwebsocket/IXSocketServer.h b/ixwebsocket/IXSocketServer.h index 5fea37d9..6f9b1998 100644 --- a/ixwebsocket/IXSocketServer.h +++ b/ixwebsocket/IXSocketServer.h @@ -6,6 +6,8 @@ #pragma once +#include "IXConnectionState.h" + #include // pair #include #include @@ -20,6 +22,8 @@ namespace ix { class SocketServer { public: + using ConnectionStateFactory = std::function()>; + SocketServer(int port = SocketServer::kDefaultPort, const std::string& host = SocketServer::kDefaultHost, int backlog = SocketServer::kDefaultTcpBacklog, @@ -27,6 +31,8 @@ namespace ix virtual ~SocketServer(); virtual void stop(); + void setConnectionStateFactory(const ConnectionStateFactory& connectionStateFactory); + const static int kDefaultPort; const static std::string kDefaultHost; const static int kDefaultTcpBacklog; @@ -60,9 +66,13 @@ namespace ix std::condition_variable _conditionVariable; std::mutex _conditionVariableMutex; + // + ConnectionStateFactory _connectionStateFactory; + // Methods void run(); - virtual void handleConnection(int fd) = 0; + virtual void handleConnection(int fd, + std::shared_ptr connectionState) = 0; virtual size_t getConnectedClientsCount() = 0; }; } diff --git a/ixwebsocket/IXWebSocketServer.cpp b/ixwebsocket/IXWebSocketServer.cpp index 3ffad19b..ffffd09d 100644 --- a/ixwebsocket/IXWebSocketServer.cpp +++ b/ixwebsocket/IXWebSocketServer.cpp @@ -49,10 +49,12 @@ namespace ix _onConnectionCallback = callback; } - void WebSocketServer::handleConnection(int fd) + void WebSocketServer::handleConnection( + int fd, + std::shared_ptr connectionState) { auto webSocket = std::make_shared(); - _onConnectionCallback(webSocket); + _onConnectionCallback(webSocket, connectionState); webSocket->disableAutomaticReconnection(); diff --git a/ixwebsocket/IXWebSocketServer.h b/ixwebsocket/IXWebSocketServer.h index 3ad4976a..4964b15f 100644 --- a/ixwebsocket/IXWebSocketServer.h +++ b/ixwebsocket/IXWebSocketServer.h @@ -20,7 +20,8 @@ namespace ix { - using OnConnectionCallback = std::function)>; + using OnConnectionCallback = std::function, + std::shared_ptr)>; class WebSocketServer : public SocketServer { public: @@ -49,7 +50,8 @@ namespace ix const static int kDefaultHandShakeTimeoutSecs; // Methods - virtual void handleConnection(int fd) final; + virtual void handleConnection(int fd, + std::shared_ptr connectionState) final; virtual size_t getConnectedClientsCount() final; }; } diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index 45b17944..73b7a350 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -148,7 +148,7 @@ namespace ix mutable std::mutex _lastSendTimePointMutex; std::chrono::time_point _lastSendTimePoint; - // No data was send through the socket for longer that the heartbeat period + // No data was send through the socket for longer than the heartbeat period bool heartBeatPeriodExceeded(); void sendOnSocket(); diff --git a/test/IXWebSocketHeartBeatTest.cpp b/test/IXWebSocketHeartBeatTest.cpp index 52e01b60..f6425563 100644 --- a/test/IXWebSocketHeartBeatTest.cpp +++ b/test/IXWebSocketHeartBeatTest.cpp @@ -128,10 +128,11 @@ namespace { // A dev/null server server.setOnConnectionCallback( - [&server, &receivedPingMessages](std::shared_ptr webSocket) + [&server, &receivedPingMessages](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server, &receivedPingMessages](ix::WebSocketMessageType messageType, + [webSocket, connectionState, &server, &receivedPingMessages](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -141,6 +142,7 @@ namespace if (messageType == ix::WebSocket_MessageType_Open) { Logger() << "New server connection"; + Logger() << "id: " << connectionState->getId(); Logger() << "Uri: " << openInfo.uri; Logger() << "Headers:"; for (auto it : openInfo.headers) diff --git a/test/IXWebSocketServerTest.cpp b/test/IXWebSocketServerTest.cpp index ab7d49e4..d80b5aa2 100644 --- a/test/IXWebSocketServerTest.cpp +++ b/test/IXWebSocketServerTest.cpp @@ -18,13 +18,32 @@ using namespace ix; namespace ix { - bool startServer(ix::WebSocketServer& server) + // Test that we can override the connectionState impl to provide our own + class ConnectionStateCustom : public ConnectionState { + void computeId() + { + // a very boring invariant id that we can test against in the unittest + _id = "foobarConnectionId"; + } + }; + + bool startServer(ix::WebSocketServer& server, + std::string& connectionId) + { + auto factory = []() -> std::shared_ptr + { + return std::make_shared(); + }; + server.setConnectionStateFactory(factory); + server.setOnConnectionCallback( - [&server](std::shared_ptr webSocket) + [&server, &connectionId](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server](ix::WebSocketMessageType messageType, + [webSocket, connectionState, + &connectionId, &server](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -34,12 +53,15 @@ namespace ix if (messageType == ix::WebSocket_MessageType_Open) { Logger() << "New connection"; + Logger() << "id: " << connectionState->getId(); Logger() << "Uri: " << openInfo.uri; Logger() << "Headers:"; for (auto it : openInfo.headers) { Logger() << it.first << ": " << it.second; } + + connectionId = connectionState->getId(); } else if (messageType == ix::WebSocket_MessageType_Close) { @@ -78,7 +100,8 @@ TEST_CASE("Websocket_server", "[websocket_server]") { int port = getFreePort(); ix::WebSocketServer server(port); - REQUIRE(startServer(server)); + std::string connectionId; + REQUIRE(startServer(server, connectionId)); std::string errMsg; bool tls = false; @@ -111,7 +134,8 @@ TEST_CASE("Websocket_server", "[websocket_server]") { int port = getFreePort(); ix::WebSocketServer server(port); - REQUIRE(startServer(server)); + std::string connectionId; + REQUIRE(startServer(server, connectionId)); std::string errMsg; bool tls = false; @@ -147,7 +171,8 @@ TEST_CASE("Websocket_server", "[websocket_server]") { int port = getFreePort(); ix::WebSocketServer server(port); - REQUIRE(startServer(server)); + std::string connectionId; + REQUIRE(startServer(server, connectionId)); std::string errMsg; bool tls = false; @@ -178,6 +203,8 @@ TEST_CASE("Websocket_server", "[websocket_server]") // Give us 500ms for the server to notice that clients went away ix::msleep(500); + REQUIRE(connectionId == "foobarConnectionId"); + server.stop(); REQUIRE(server.getClients().size() == 0); } diff --git a/test/cmd_websocket_chat.cpp b/test/cmd_websocket_chat.cpp index e8dc177d..2907edc8 100644 --- a/test/cmd_websocket_chat.cpp +++ b/test/cmd_websocket_chat.cpp @@ -217,10 +217,11 @@ namespace bool startServer(ix::WebSocketServer& server) { server.setOnConnectionCallback( - [&server](std::shared_ptr webSocket) + [&server](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server](ix::WebSocketMessageType messageType, + [webSocket, connectionState, &server](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -230,6 +231,7 @@ namespace if (messageType == ix::WebSocket_MessageType_Open) { Logger() << "New connection"; + Logger() << "id: " << connectionState->getId(); Logger() << "Uri: " << openInfo.uri; Logger() << "Headers:"; for (auto it : openInfo.headers) diff --git a/third_party/remote_trailing_whitespaces.sh b/third_party/remote_trailing_whitespaces.sh index 7255d50a..14127299 100644 --- a/third_party/remote_trailing_whitespaces.sh +++ b/third_party/remote_trailing_whitespaces.sh @@ -1,2 +1,3 @@ find . -type f -name '*.cpp' -exec sed -i '' 's/[[:space:]]*$//' {} \+ find . -type f -name '*.h' -exec sed -i '' 's/[[:space:]]*$//' {} \+ +find . -type f -name '*.md' -exec sed -i '' 's/[[:space:]]*$//' {} \+ diff --git a/ws/IXRedisClient.cpp b/ws/IXRedisClient.cpp index 95a1ad32..d023c08b 100644 --- a/ws/IXRedisClient.cpp +++ b/ws/IXRedisClient.cpp @@ -114,7 +114,7 @@ namespace ix return false; } - // The first line of the response describe the return type, + // The first line of the response describe the return type, // => *3 (an array of 3 elements) auto lineResult = _socket->readLine(nullptr); auto lineValid = lineResult.first; diff --git a/ws/README.md b/ws/README.md index bb1395e1..909e1438 100644 --- a/ws/README.md +++ b/ws/README.md @@ -29,7 +29,7 @@ Subcommands: ws transfer # running on port 8080. # Start receiver first -ws receive ws://localhost:8080 +ws receive ws://localhost:8080 # Then send a file. File will be received and written to disk by the receiver process ws send ws://localhost:8080 /file/to/path diff --git a/ws/ws_broadcast_server.cpp b/ws/ws_broadcast_server.cpp index 12e7d068..5473c376 100644 --- a/ws/ws_broadcast_server.cpp +++ b/ws/ws_broadcast_server.cpp @@ -17,10 +17,11 @@ namespace ix ix::WebSocketServer server(port, hostname); server.setOnConnectionCallback( - [&server](std::shared_ptr webSocket) + [&server](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server](ix::WebSocketMessageType messageType, + [webSocket, connectionState, &server](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -30,6 +31,7 @@ namespace ix if (messageType == ix::WebSocket_MessageType_Open) { std::cerr << "New connection" << std::endl; + std::cerr << "id: " << connectionState->getId() << std::endl; std::cerr << "Uri: " << openInfo.uri << std::endl; std::cerr << "Headers:" << std::endl; for (auto it : openInfo.headers) diff --git a/ws/ws_echo_server.cpp b/ws/ws_echo_server.cpp index ac778196..4ab0add4 100644 --- a/ws/ws_echo_server.cpp +++ b/ws/ws_echo_server.cpp @@ -17,10 +17,11 @@ namespace ix ix::WebSocketServer server(port, hostname); server.setOnConnectionCallback( - [](std::shared_ptr webSocket) + [](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket](ix::WebSocketMessageType messageType, + [webSocket, connectionState](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -30,6 +31,7 @@ namespace ix if (messageType == ix::WebSocket_MessageType_Open) { std::cerr << "New connection" << std::endl; + std::cerr << "id: " << connectionState->getId() << std::endl; std::cerr << "Uri: " << openInfo.uri << std::endl; std::cerr << "Headers:" << std::endl; for (auto it : openInfo.headers) diff --git a/ws/ws_transfer.cpp b/ws/ws_transfer.cpp index e4564d39..58646010 100644 --- a/ws/ws_transfer.cpp +++ b/ws/ws_transfer.cpp @@ -17,10 +17,11 @@ namespace ix ix::WebSocketServer server(port, hostname); server.setOnConnectionCallback( - [&server](std::shared_ptr webSocket) + [&server](std::shared_ptr webSocket, + std::shared_ptr connectionState) { webSocket->setOnMessageCallback( - [webSocket, &server](ix::WebSocketMessageType messageType, + [webSocket, connectionState, &server](ix::WebSocketMessageType messageType, const std::string& str, size_t wireSize, const ix::WebSocketErrorInfo& error, @@ -30,6 +31,7 @@ namespace ix if (messageType == ix::WebSocket_MessageType_Open) { std::cerr << "New connection" << std::endl; + std::cerr << "id: " << connectionState->getId() << std::endl; std::cerr << "Uri: " << openInfo.uri << std::endl; std::cerr << "Headers:" << std::endl; for (auto it : openInfo.headers)