diff --git a/README.md b/README.md index 8da9d4a1..cf002213 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,11 @@ Here is what the client API looks like. ix::WebSocket webSocket; std::string url("ws://localhost:8080/"); -webSocket.configure(url); +webSocket.setUrl(url); + +// Optional heart beat, sent every 45 seconds when there isn't any traffic +// to make sure that load balancers do not kill an idle connection. +webSocket.setHeartBeatPeriod(45); // Setup a callback to be fired when a message or an event (open, close, error) is received webSocket.setOnMessageCallback( @@ -305,4 +309,13 @@ A ping message can be sent to the server, with an optional data string. ``` websocket.ping("ping data, optional (empty string is ok): limited to 125 bytes long"); + +### Heartbeat. + +You can configure an optional heart beat / keep-alive, sent every 45 seconds +when there isn't any traffic to make sure that load balancers do not kill an +idle connection. + +``` +webSocket.setHeartBeatPeriod(45); ``` diff --git a/ixwebsocket/IXSocket.cpp b/ixwebsocket/IXSocket.cpp index cfedb3b0..d456d3d2 100644 --- a/ixwebsocket/IXSocket.cpp +++ b/ixwebsocket/IXSocket.cpp @@ -21,6 +21,9 @@ namespace ix { + const int Socket::kDefaultPollNoTimeout = -1; // No poll timeout by default + const int Socket::kDefaultPollTimeout = kDefaultPollNoTimeout; + Socket::Socket(int fd) : _sockfd(fd) { @@ -32,14 +35,8 @@ namespace ix close(); } - void Socket::poll(const OnPollCallback& onPollCallback) + void Socket::poll(const OnPollCallback& onPollCallback, int timeoutSecs) { - if (_sockfd == -1) - { - onPollCallback(); - return; - } - fd_set rfds; FD_ZERO(&rfds); FD_SET(_sockfd, &rfds); @@ -48,11 +45,26 @@ namespace ix FD_SET(_eventfd.getFd(), &rfds); #endif + struct timeval timeout; + timeout.tv_sec = timeoutSecs; + timeout.tv_usec = 0; + int sockfd = _sockfd; int nfds = (std::max)(sockfd, _eventfd.getFd()); - select(nfds + 1, &rfds, nullptr, nullptr, nullptr); + int ret = select(nfds + 1, &rfds, nullptr, nullptr, + (timeoutSecs == kDefaultPollNoTimeout) ? nullptr : &timeout); - onPollCallback(); + PollResultType pollResult = PollResultType_ReadyForRead; + if (ret < 0) + { + pollResult = PollResultType_Error; + } + else if (ret == 0) + { + pollResult = PollResultType_Timeout; + } + + onPollCallback(pollResult); } void Socket::wakeUpFromPoll() diff --git a/ixwebsocket/IXSocket.h b/ixwebsocket/IXSocket.h index 5004d66b..fc72a291 100644 --- a/ixwebsocket/IXSocket.h +++ b/ixwebsocket/IXSocket.h @@ -21,16 +21,24 @@ typedef SSIZE_T ssize_t; namespace ix { + enum PollResultType + { + PollResultType_ReadyForRead = 0, + PollResultType_Timeout = 1, + PollResultType_Error = 2 + }; + class Socket { public: - using OnPollCallback = std::function; + using OnPollCallback = std::function; Socket(int fd = -1); virtual ~Socket(); void configure(); - virtual void poll(const OnPollCallback& onPollCallback); + virtual void poll(const OnPollCallback& onPollCallback, + int timeoutSecs = kDefaultPollTimeout); virtual void wakeUpFromPoll(); // Virtual methods @@ -62,5 +70,9 @@ namespace ix std::atomic _sockfd; std::mutex _socketMutex; EventFd _eventfd; + + private: + static const int kDefaultPollTimeout; + static const int kDefaultPollNoTimeout; }; } diff --git a/ixwebsocket/IXSocketServer.cpp b/ixwebsocket/IXSocketServer.cpp index a6344e27..56012e88 100644 --- a/ixwebsocket/IXSocketServer.cpp +++ b/ixwebsocket/IXSocketServer.cpp @@ -71,9 +71,11 @@ namespace ix (char*) &enable, sizeof(enable)) < 0) { std::stringstream ss; - ss << "SocketServer::listen() error calling setsockopt(SO_REUSEADDR): " - << strerror(errno); + ss << "SocketServer::listen() error calling setsockopt(SO_REUSEADDR) " + << "at address " << _host << ":" << _port + << " : " << strerror(Socket::getErrno()); + ::close(_serverFd); return std::make_pair(false, ss.str()); } @@ -93,21 +95,25 @@ namespace ix if (bind(_serverFd, (struct sockaddr *)&server, sizeof(server)) < 0) { std::stringstream ss; - ss << "SocketServer::listen() error calling bind: " - << strerror(Socket::getErrno()); + ss << "SocketServer::listen() error calling bind " + << "at address " << _host << ":" << _port + << " : " << strerror(Socket::getErrno()); + ::close(_serverFd); return std::make_pair(false, ss.str()); } - /* - * Listen for connections. Specify the tcp backlog. - */ - if (::listen(_serverFd, _backlog) != 0) + // + // Listen for connections. Specify the tcp backlog. + // + if (::listen(_serverFd, _backlog) < 0) { std::stringstream ss; - ss << "SocketServer::listen() error calling listen: " - << strerror(Socket::getErrno()); + ss << "SocketServer::listen() error calling listen " + << "at address " << _host << ":" << _port + << " : " << strerror(Socket::getErrno()); + ::close(_serverFd); return std::make_pair(false, ss.str()); } @@ -136,6 +142,7 @@ namespace ix _stop = false; _conditionVariable.notify_one(); + ::close(_serverFd); } void SocketServer::run() diff --git a/ixwebsocket/IXWebSocket.cpp b/ixwebsocket/IXWebSocket.cpp index b51ed980..d3463cf3 100644 --- a/ixwebsocket/IXWebSocket.cpp +++ b/ixwebsocket/IXWebSocket.cpp @@ -31,12 +31,14 @@ namespace ix { OnTrafficTrackerCallback WebSocket::_onTrafficTrackerCallback = nullptr; const int WebSocket::kDefaultHandShakeTimeoutSecs(60); + const int WebSocket::kDefaultHeartBeatPeriod(-1); WebSocket::WebSocket() : _onMessageCallback(OnMessageCallback()), _stop(false), _automaticReconnection(true), - _handshakeTimeoutSecs(kDefaultHandShakeTimeoutSecs) + _handshakeTimeoutSecs(kDefaultHandShakeTimeoutSecs), + _heartBeatPeriod(kDefaultHeartBeatPeriod) { _ws.setOnCloseCallback( [this](uint16_t code, const std::string& reason, size_t wireSize) @@ -77,6 +79,18 @@ namespace ix return _perMessageDeflateOptions; } + void WebSocket::setHeartBeatPeriod(int hearBeatPeriod) + { + std::lock_guard lock(_configMutex); + _heartBeatPeriod = hearBeatPeriod; + } + + int WebSocket::getHeartBeatPeriod() const + { + std::lock_guard lock(_configMutex); + return _heartBeatPeriod; + } + void WebSocket::start() { if (_thread.joinable()) return; // we've already been started @@ -110,7 +124,8 @@ namespace ix { { std::lock_guard lock(_configMutex); - _ws.configure(_perMessageDeflateOptions); + _ws.configure(_perMessageDeflateOptions, + _heartBeatPeriod); } WebSocketInitResult status = _ws.connectToUrl(_url, timeoutSecs); @@ -130,7 +145,7 @@ namespace ix { { std::lock_guard lock(_configMutex); - _ws.configure(_perMessageDeflateOptions); + _ws.configure(_perMessageDeflateOptions, _heartBeatPeriod); } WebSocketInitResult status = _ws.connectToSocket(fd, timeoutSecs); diff --git a/ixwebsocket/IXWebSocket.h b/ixwebsocket/IXWebSocket.h index ef75d8df..081b0cc4 100644 --- a/ixwebsocket/IXWebSocket.h +++ b/ixwebsocket/IXWebSocket.h @@ -86,7 +86,8 @@ namespace ix void setUrl(const std::string& url); void setPerMessageDeflateOptions(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions); - void setHandshakeTimeout(int _handshakeTimeoutSecs); + void setHandshakeTimeout(int handshakeTimeoutSecs); + void setHeartBeatPeriod(int hearBeatPeriod); // Run asynchronously, by calling start and stop. void start(); @@ -107,6 +108,7 @@ namespace ix ReadyState getReadyState() const; const std::string& getUrl() const; const WebSocketPerMessageDeflateOptions& getPerMessageDeflateOptions() const; + int getHeartBeatPeriod() const; void enableAutomaticReconnection(); void disableAutomaticReconnection(); @@ -142,6 +144,10 @@ namespace ix std::atomic _handshakeTimeoutSecs; static const int kDefaultHandShakeTimeoutSecs; + // Optional Heartbeat + int _heartBeatPeriod; + static const int kDefaultHeartBeatPeriod; + friend class WebSocketServer; }; } diff --git a/ixwebsocket/IXWebSocketTransport.cpp b/ixwebsocket/IXWebSocketTransport.cpp index 4c92a901..15b155f8 100644 --- a/ixwebsocket/IXWebSocketTransport.cpp +++ b/ixwebsocket/IXWebSocketTransport.cpp @@ -33,12 +33,17 @@ namespace ix { + const std::string WebSocketTransport::kHeartBeatPingMessage("ixwebsocket::hearbeat"); + const int WebSocketTransport::kDefaultHeartBeatPeriod(-1); + WebSocketTransport::WebSocketTransport() : _readyState(CLOSED), _closeCode(0), _closeWireSize(0), _enablePerMessageDeflate(false), - _requestInitCancellation(false) + _requestInitCancellation(false), + _heartBeatPeriod(kDefaultHeartBeatPeriod), + _lastSendTimePoint(std::chrono::steady_clock::now()) { } @@ -48,10 +53,12 @@ namespace ix ; } - void WebSocketTransport::configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions) + void WebSocketTransport::configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, + int hearBeatPeriod) { _perMessageDeflateOptions = perMessageDeflateOptions; _enablePerMessageDeflate = _perMessageDeflateOptions.enabled(); + _heartBeatPeriod = hearBeatPeriod; } // Client @@ -149,11 +156,30 @@ namespace ix _onCloseCallback = onCloseCallback; } + bool WebSocketTransport::exceedSendHeartBeatTimeOut() + { + std::lock_guard lock(_lastSendTimePointMutex); + auto now = std::chrono::steady_clock::now(); + return now - _lastSendTimePoint > std::chrono::seconds(_heartBeatPeriod); + } + void WebSocketTransport::poll() { _socket->poll( - [this]() + [this](PollResultType pollResult) { + // If (1) heartbeat is enabled, and (2) no data was received or + // send for a duration exceeding our heart-beat period, send a + // ping to the server. + if (pollResult == PollResultType_Timeout && + exceedSendHeartBeatTimeOut()) + { + std::stringstream ss; + ss << kHeartBeatPingMessage << "::" << _heartBeatPeriod << "s"; + sendPing(ss.str()); + return; + } + while (true) { int N = (int) _rxbuf.size(); @@ -185,7 +211,8 @@ namespace ix _socket->close(); setReadyState(CLOSED); } - }); + }, + _heartBeatPeriod); } bool WebSocketTransport::isSendBufferEmpty() const @@ -557,6 +584,9 @@ namespace ix _txbuf.erase(_txbuf.begin(), _txbuf.begin() + ret); } } + + std::lock_guard lck(_lastSendTimePointMutex); + _lastSendTimePoint = std::chrono::steady_clock::now(); } void WebSocketTransport::close() diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index a2289e83..90585c78 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -57,7 +57,8 @@ namespace ix WebSocketTransport(); ~WebSocketTransport(); - void configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions); + void configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, + int hearBeatPeriod); WebSocketInitResult connectToUrl(const std::string& url, // Client int timeoutSecs); @@ -116,6 +117,16 @@ namespace ix // Used to cancel dns lookup + socket connect + http upgrade std::atomic _requestInitCancellation; + + // Optional Heartbeat + int _heartBeatPeriod; + static const int kDefaultHeartBeatPeriod; + const static std::string kHeartBeatPingMessage; + mutable std::mutex _lastSendTimePointMutex; + std::chrono::time_point _lastSendTimePoint; + + // No data was send through the socket for longer that the hearbeat period + bool exceedSendHeartBeatTimeOut(); void sendOnSocket(); WebSocketSendInfo sendData(wsheader_type::opcode_type type, diff --git a/makefile b/makefile index 5fe7df6c..04e384f8 100644 --- a/makefile +++ b/makefile @@ -24,6 +24,8 @@ test_server: (cd test && npm i ws && node broadcast-server.js) # env TEST=Websocket_server make test +# env TEST=websocket_server make test +# env TEST=heartbeat make test test: python test/run.py diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c763530e..3a133b4c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,6 +34,7 @@ set (SOURCES if (NOT WIN32) list(APPEND SOURCES IXWebSocketServerTest.cpp + IXWebSocketHeartBeatTest.cpp cmd_websocket_chat.cpp IXWebSocketTestConnectionDisconnection.cpp ) diff --git a/test/IXWebSocketHeartBeatTest.cpp b/test/IXWebSocketHeartBeatTest.cpp new file mode 100644 index 00000000..cc1d5510 --- /dev/null +++ b/test/IXWebSocketHeartBeatTest.cpp @@ -0,0 +1,222 @@ +/* + * IXWebSocketHeartBeatTest.cpp + * Author: Benjamin Sergeant + * Copyright (c) 2019 Machine Zone. All rights reserved. + */ + +#include +#include +#include +#include +#include + +#include "IXTest.h" + +#include "catch.hpp" + +using namespace ix; + +namespace +{ + class WebSocketClient + { + public: + WebSocketClient(int port); + + void subscribe(const std::string& channel); + void start(); + void stop(); + bool isReady() const; + void sendMessage(const std::string& text); + + private: + ix::WebSocket _webSocket; + int _port; + }; + + WebSocketClient::WebSocketClient(int port) + : _port(port) + { + ; + } + + bool WebSocketClient::isReady() const + { + return _webSocket.getReadyState() == ix::WebSocket_ReadyState_Open; + } + + void WebSocketClient::stop() + { + _webSocket.stop(); + } + + void WebSocketClient::start() + { + std::string url; + { + std::stringstream ss; + ss << "ws://localhost:" + << _port + << "/"; + + url = ss.str(); + } + + _webSocket.setUrl(url); + + // The important bit for this test. + // Set a 1 second hearbeat ; if no traffic is present on the connection for 1 second + // a ping message will be sent by the client. + _webSocket.setHeartBeatPeriod(1); + + std::stringstream ss; + log(std::string("Connecting to url: ") + url); + + _webSocket.setOnMessageCallback( + [](ix::WebSocketMessageType messageType, + const std::string& str, + size_t wireSize, + const ix::WebSocketErrorInfo& error, + const ix::WebSocketOpenInfo& openInfo, + const ix::WebSocketCloseInfo& closeInfo) + { + std::stringstream ss; + if (messageType == ix::WebSocket_MessageType_Open) + { + log("client connected"); + } + else if (messageType == ix::WebSocket_MessageType_Close) + { + log("client disconnected"); + } + else if (messageType == ix::WebSocket_MessageType_Error) + { + ss << "Error ! " << error.reason; + log(ss.str()); + } + else if (messageType == ix::WebSocket_MessageType_Pong) + { + ss << "Received pong message " << str; + log(ss.str()); + } + else if (messageType == ix::WebSocket_MessageType_Ping) + { + ss << "Received ping message " << str; + log(ss.str()); + } + else if (messageType == ix::WebSocket_MessageType_Message) + { + ss << "Received message " << str; + log(ss.str()); + } + else + { + ss << "Invalid ix::WebSocketMessageType"; + log(ss.str()); + } + }); + + _webSocket.start(); + } + + void WebSocketClient::sendMessage(const std::string& text) + { + _webSocket.send(text); + } + + bool startServer(ix::WebSocketServer& server, std::atomic& receivedPingMessages) + { + // A dev/null server + server.setOnConnectionCallback( + [&server, &receivedPingMessages](std::shared_ptr webSocket) + { + webSocket->setOnMessageCallback( + [webSocket, &server, &receivedPingMessages](ix::WebSocketMessageType messageType, + const std::string& str, + size_t wireSize, + const ix::WebSocketErrorInfo& error, + const ix::WebSocketOpenInfo& openInfo, + const ix::WebSocketCloseInfo& closeInfo) + { + if (messageType == ix::WebSocket_MessageType_Open) + { + Logger() << "New server connection"; + Logger() << "Uri: " << openInfo.uri; + Logger() << "Headers:"; + for (auto it : openInfo.headers) + { + Logger() << it.first << ": " << it.second; + } + } + else if (messageType == ix::WebSocket_MessageType_Close) + { + log("Server closed connection"); + } + else if (messageType == ix::WebSocket_MessageType_Ping) + { + log("Server received a ping"); + receivedPingMessages++; + } + } + ); + } + ); + + auto res = server.listen(); + if (!res.first) + { + log(res.second); + return false; + } + + server.start(); + return true; + } +} + +TEST_CASE("Websocket_heartbeat", "[heartbeat]") +{ + SECTION("Make sure that ping messages are sent during heartbeat.") + { + ix::setupWebSocketTrafficTrackerCallback(); + + int port = 8093; + ix::WebSocketServer server(port); + std::atomic serverReceivedPingMessages(0); + REQUIRE(startServer(server, serverReceivedPingMessages)); + + std::string session = ix::generateSessionId(); + WebSocketClient webSocketClientA(port); + WebSocketClient webSocketClientB(port); + + webSocketClientA.start(); + webSocketClientB.start(); + + // Wait for all chat instance to be ready + while (true) + { + if (webSocketClientA.isReady() && webSocketClientB.isReady()) break; + ix::msleep(10); + } + + REQUIRE(server.getClients().size() == 2); + + ix::msleep(900); + webSocketClientB.sendMessage("hello world"); + ix::msleep(900); + webSocketClientB.sendMessage("hello world"); + ix::msleep(900); + + webSocketClientA.stop(); + webSocketClientB.stop(); + + REQUIRE(serverReceivedPingMessages >= 2); + REQUIRE(serverReceivedPingMessages <= 4); + + // Give us 500ms for the server to notice that clients went away + ix::msleep(500); + REQUIRE(server.getClients().size() == 0); + + ix::reportWebSocketTraffic(); + } +} diff --git a/test/IXWebSocketTestConnectionDisconnection.cpp b/test/IXWebSocketTestConnectionDisconnection.cpp index 945a2940..d9bf9bc8 100644 --- a/test/IXWebSocketTestConnectionDisconnection.cpp +++ b/test/IXWebSocketTestConnectionDisconnection.cpp @@ -52,12 +52,12 @@ namespace log(std::string("Connecting to url: ") + url); _webSocket.setOnMessageCallback( - [this](ix::WebSocketMessageType messageType, - const std::string& str, - size_t wireSize, - const ix::WebSocketErrorInfo& error, - const ix::WebSocketOpenInfo& openInfo, - const ix::WebSocketCloseInfo& closeInfo) + [](ix::WebSocketMessageType messageType, + const std::string& str, + size_t wireSize, + const ix::WebSocketErrorInfo& error, + const ix::WebSocketOpenInfo& openInfo, + const ix::WebSocketCloseInfo& closeInfo) { std::stringstream ss; if (messageType == ix::WebSocket_MessageType_Open) diff --git a/test/cmd_websocket_chat.cpp b/test/cmd_websocket_chat.cpp index da131df1..b3a49f45 100644 --- a/test/cmd_websocket_chat.cpp +++ b/test/cmd_websocket_chat.cpp @@ -4,11 +4,6 @@ * Copyright (c) 2017 Machine Zone. All rights reserved. */ -// -// Simple chat program that talks to the node.js server at -// websocket_chat_server/broacast-server.js -// - #include #include #include diff --git a/test/run.py b/test/run.py index 2f78c61d..3cee95bd 100644 --- a/test/run.py +++ b/test/run.py @@ -19,7 +19,7 @@ if osName == 'Windows': testBinary ='ixwebsocket_unittest.exe' else: generator = '' - make = 'make' + make = 'make -j6' testBinary ='./ixwebsocket_unittest' sanitizersFlags = {