(websocket server) add a new simpler API to handle client connections / that API does not trigger a memory leak while the previous one did

This commit is contained in:
Benjamin Sergeant 2020-07-23 19:29:41 -07:00
parent ffde283a4b
commit 2798886c0b
16 changed files with 367 additions and 258 deletions

View File

@ -1,6 +1,10 @@
# Changelog
All changes to this project will be documented in this file.
## [9.10.0] - 2020-07-23
(websocket server) add a new simpler API to handle client connections / that API does not trigger a memory leak while the previous one did
## [9.9.3] - 2020-07-17
(build) merge platform specific files which were used to have different implementations for setting a thread name into a single file, to make it easier to include every source files and build the ixwebsocket library (fix #226)

View File

@ -246,6 +246,8 @@ uint32_t m = webSocket.getMaxWaitBetweenReconnectionRetries();
## WebSocket server API
### Legacy api
```cpp
#include <ixwebsocket/IXWebSocketServer.h>
@ -312,6 +314,74 @@ server.wait();
```
### New api
The new API does not require to use 2 nested callbacks, which is a bit annoying. The real fix is that there was a memory leak due to a shared_ptr cycle, due to passing down a shared_ptr<WebSocket> down to the callbacks.
The webSocket reference is guaranteed to be always valid ; by design the callback will never be invoked with a null webSocket object.
```cpp
#include <ixwebsocket/IXWebSocketServer.h>
...
// Run a server on localhost at a given port.
// Bound host name, max connections and listen backlog can also be passed in as parameters.
ix::WebSocketServer server(port);
server.setOnClientMessageCallback(std::shared_ptr<ConnectionState> connectionState,
ConnectionInfo& connectionInfo,
WebSocket& webSocket,
const WebSocketMessagePtr& msg)
{
// The ConnectionInfo object contains information about the connection,
// at this point only the client ip address and the port.
std::cout << "Remote ip: " << connectionInfo.remoteIp << std::endl;
if (msg->type == ix::WebSocketMessageType::Open)
{
std::cout << "New connection" << std::endl;
// A connection state object is available, and has a default id
// You can subclass ConnectionState and pass an alternate factory
// to override it. It is useful if you want to store custom
// attributes per connection (authenticated bool flag, attributes, etc...)
std::cout << "id: " << connectionState->getId() << std::endl;
// The uri the client did connect to.
std::cout << "Uri: " << msg->openInfo.uri << std::endl;
std::cout << "Headers:" << std::endl;
for (auto it : msg->openInfo.headers)
{
std::cout << it.first << ": " << it.second << std::endl;
}
}
else if (msg->type == ix::WebSocketMessageType::Message)
{
// For an echo server, we just send back to the client whatever was received by the server
// All connected clients are available in an std::set. See the broadcast cpp example.
// Second parameter tells whether we are sending the message in binary or text mode.
// Here we send it in the same mode as it was received.
webSocket.send(msg->str, msg->binary);
}
);
auto res = server.listen();
if (!res.first)
{
// Error handling
return 1;
}
// Run the server in the background. Server can be stoped by calling server.stop()
server.start();
// Block until server.stop() is called.
server.wait();
```
## HTTP client API
```cpp

View File

@ -125,9 +125,8 @@ namespace ix
if (std::get<0>(ret))
{
auto response = _onConnectionCallback(std::get<2>(ret),
connectionState,
std::move(connectionInfo));
auto response =
_onConnectionCallback(std::get<2>(ret), connectionState, std::move(connectionInfo));
if (!Http::sendResponse(response, socket))
{
logError("Cannot send response");
@ -203,8 +202,7 @@ namespace ix
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections
//
setOnConnectionCallback(
[this,
redirectUrl](HttpRequestPtr request,
[this, redirectUrl](HttpRequestPtr request,
std::shared_ptr<ConnectionState> /*connectionState*/,
std::unique_ptr<ConnectionInfo> connectionInfo) -> HttpResponsePtr {
WebSocketHttpHeaders headers;

View File

@ -7,17 +7,17 @@
// unix systems
#if defined(__APPLE__) || defined(__linux__) || defined(BSD)
# include <pthread.h>
#include <pthread.h>
#endif
// freebsd needs this header as well
#if defined(BSD)
# include <pthread_np.h>
#include <pthread_np.h>
#endif
// Windows
#ifdef _WIN32
# include <Windows.h>
#include <Windows.h>
#endif
namespace ix

View File

@ -379,10 +379,13 @@ namespace ix
// Launch the handleConnection work asynchronously in its own thread.
std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
_connectionsThreads.push_back(std::make_pair(
_connectionsThreads.push_back(
std::make_pair(connectionState,
std::thread(&SocketServer::handleConnection,
this,
std::move(socket),
connectionState,
std::thread(
&SocketServer::handleConnection, this, std::move(socket), connectionState, std::move(connectionInfo))));
std::move(connectionInfo))));
}
}

View File

@ -71,6 +71,11 @@ namespace ix
_onConnectionCallback = callback;
}
void WebSocketServer::setOnClientMessageCallback(const OnClientMessageCallback& callback)
{
_onClientMessageCallback = callback;
}
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo)
@ -78,7 +83,26 @@ namespace ix
setThreadName("WebSocketServer::" + connectionState->getId());
auto webSocket = std::make_shared<WebSocket>();
if (_onConnectionCallback)
{
_onConnectionCallback(webSocket, connectionState, std::move(connectionInfo));
}
else if (_onClientMessageCallback)
{
webSocket->setOnMessageCallback(
[this, &ws = *webSocket.get(), connectionState, &ci = *connectionInfo.get()](
const WebSocketMessagePtr& msg) {
_onClientMessageCallback(connectionState, ci, ws, msg);
});
}
else
{
logError(
"WebSocketServer Application developer error: No server callback is registerered.");
logError("Missing call to setOnConnectionCallback or setOnClientMessageCallback.");
connectionState->setTerminated();
return;
}
webSocket->disableAutomaticReconnection();

View File

@ -23,9 +23,15 @@ namespace ix
{
public:
using OnConnectionCallback =
std::function<void(std::shared_ptr<WebSocket>, std::shared_ptr<ConnectionState>,
std::function<void(std::shared_ptr<WebSocket>,
std::shared_ptr<ConnectionState>,
std::unique_ptr<ConnectionInfo> connectionInfo)>;
using OnClientMessageCallback = std::function<void(std::shared_ptr<ConnectionState>,
ConnectionInfo&,
WebSocket&,
const WebSocketMessagePtr&)>;
WebSocketServer(int port = SocketServer::kDefaultPort,
const std::string& host = SocketServer::kDefaultHost,
int backlog = SocketServer::kDefaultTcpBacklog,
@ -40,6 +46,7 @@ namespace ix
void disablePerMessageDeflate();
void setOnConnectionCallback(const OnConnectionCallback& callback);
void setOnClientMessageCallback(const OnClientMessageCallback& callback);
// Get all the connected clients
std::set<std::shared_ptr<WebSocket>> getClients();
@ -53,6 +60,7 @@ namespace ix
bool _enablePerMessageDeflate;
OnConnectionCallback _onConnectionCallback;
OnClientMessageCallback _onClientMessageCallback;
std::mutex _clientsMutex;
std::set<std::shared_ptr<WebSocket>> _clients;

View File

@ -6,4 +6,4 @@
#pragma once
#define IX_WEBSOCKET_VERSION "9.9.3"
#define IX_WEBSOCKET_VERSION "9.10.0"

View File

@ -88,8 +88,8 @@ namespace ix
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, &server](const ix::WebSocketMessagePtr& msg) {
webSocket->setOnMessageCallback([webSocket, connectionState, remoteIp, &server](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{
TLogger() << "New connection";

View File

@ -189,12 +189,13 @@ namespace
bool preferTLS = true;
server.setTLSOptions(makeServerTLSOptions(preferTLS));
server.setOnConnectionCallback([&server, &connectionId](
std::shared_ptr<ix::WebSocket> webSocket,
server.setOnConnectionCallback(
[&server, &connectionId](std::shared_ptr<ix::WebSocket> webSocket,
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback([webSocket, connectionState, remoteIp, &connectionId, &server](
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, &connectionId, &server](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{

View File

@ -198,8 +198,8 @@ namespace
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, &server](const ix::WebSocketMessagePtr& msg) {
webSocket->setOnMessageCallback([webSocket, connectionState, remoteIp, &server](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{
TLogger() << "New connection";

View File

@ -5,13 +5,11 @@
*/
#include "IXTest.h"
#include "catch.hpp"
#include <memory>
#include <sstream>
#include <ixwebsocket/IXWebSocket.h>
#include <ixwebsocket/IXWebSocketServer.h>
#include <memory>
#include <sstream>
using namespace ix;
@ -69,8 +67,7 @@ namespace
std::stringstream ss;
log(std::string("Connecting to url: ") + url);
_webSocket.setOnMessageCallback([this](const ix::WebSocketMessagePtr& msg)
{
_webSocket.setOnMessageCallback([this](const ix::WebSocketMessagePtr& msg) {
std::stringstream ss;
if (msg->type == ix::WebSocketMessageType::Open)
{
@ -118,32 +115,35 @@ TEST_CASE("Websocket leak test")
int port = getFreePort();
WebSocketServer server(port);
server.setOnConnectionCallback([&webSocketPtr](std::shared_ptr<ix::WebSocket> webSocket,
server.setOnConnectionCallback(
[&webSocketPtr](std::shared_ptr<ix::WebSocket> webSocket,
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo)
{
std::unique_ptr<ConnectionInfo> connectionInfo) {
// original ptr in WebSocketServer::handleConnection and the callback argument
REQUIRE(webSocket.use_count() == 2);
webSocket->setOnMessageCallback([&webSocketPtr, webSocket, connectionState](const ix::WebSocketMessagePtr& msg)
{
webSocket->setOnMessageCallback([&webSocketPtr, webSocket, connectionState](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{
log(std::string("New connection id: ") + connectionState->getId());
// original ptr in WebSocketServer::handleConnection, captured ptr of this callback, and ptr in WebSocketServer::_clients
// original ptr in WebSocketServer::handleConnection, captured ptr of
// this callback, and ptr in WebSocketServer::_clients
REQUIRE(webSocket.use_count() == 3);
webSocketPtr = std::shared_ptr<WebSocket>(webSocket);
REQUIRE(webSocket.use_count() == 4);
}
else if (msg->type == ix::WebSocketMessageType::Close)
{
log(std::string("Client closed connection id: ") + connectionState->getId());
log(std::string("Client closed connection id: ") +
connectionState->getId());
}
else
{
log(std::string(msg->str));
}
});
// original ptr in WebSocketServer::handleConnection, argument of this callback, and captured ptr in websocket callback
// original ptr in WebSocketServer::handleConnection, argument of this callback,
// and captured ptr in websocket callback
REQUIRE(webSocket.use_count() == 3);
});
@ -169,7 +169,8 @@ TEST_CASE("Websocket leak test")
ix::msleep(500);
REQUIRE(server.getClients().size() == 0);
// websocket should only be referenced by webSocketPtr but is still used by the websocket callback
// websocket should only be referenced by webSocketPtr but is still used by the
// websocket callback
REQUIRE(webSocketPtr.use_count() == 1);
webSocketPtr->setOnMessageCallback(nullptr);
// websocket should only be referenced by webSocketPtr

View File

@ -33,12 +33,13 @@ namespace ix
};
server.setConnectionStateFactory(factory);
server.setOnConnectionCallback([&server, &connectionId](
std::shared_ptr<ix::WebSocket> webSocket,
server.setOnConnectionCallback(
[&server, &connectionId](std::shared_ptr<ix::WebSocket> webSocket,
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback([webSocket, connectionState, remoteIp, &connectionId, &server](
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, &connectionId, &server](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{

View File

@ -21,7 +21,8 @@ bool startServer(ix::WebSocketServer& server, std::string& subProtocols)
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback([webSocket, connectionState, remoteIp, &server, &subProtocols](
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, &server, &subProtocols](
const ix::WebSocketMessagePtr& msg) {
if (msg->type == ix::WebSocketMessageType::Open)
{

View File

@ -484,7 +484,24 @@ int main(int argc, char** argv)
cobraBotConfig.cobraConfig.socketTLSOptions = tlsOptions;
int ret = 1;
if (app.got_subcommand("transfer"))
if (app.got_subcommand("connect"))
{
ret = ix::ws_connect_main(url,
headers,
disableAutomaticReconnection,
disablePerMessageDeflate,
binaryMode,
maxWaitBetweenReconnectionRetries,
tlsOptions,
subprotocol,
pingIntervalSecs);
}
else if (app.got_subcommand("echo_server"))
{
ret = ix::ws_echo_server_main(
port, greetings, hostname, tlsOptions, ipv6, disablePerMessageDeflate, disablePong);
}
else if (app.got_subcommand("transfer"))
{
ret = ix::ws_transfer_main(port, hostname, tlsOptions);
}
@ -497,27 +514,10 @@ int main(int argc, char** argv)
bool enablePerMessageDeflate = false;
ret = ix::ws_receive_main(url, enablePerMessageDeflate, delayMs, tlsOptions);
}
else if (app.got_subcommand("connect"))
{
ret = ix::ws_connect_main(url,
headers,
disableAutomaticReconnection,
disablePerMessageDeflate,
binaryMode,
maxWaitBetweenReconnectionRetries,
tlsOptions,
subprotocol,
pingIntervalSecs);
}
else if (app.got_subcommand("chat"))
{
ret = ix::ws_chat_main(url, user);
}
else if (app.got_subcommand("echo_server"))
{
ret = ix::ws_echo_server_main(
port, greetings, hostname, tlsOptions, ipv6, disablePerMessageDeflate, disablePong);
}
else if (app.got_subcommand("broadcast_server"))
{
ret = ix::ws_broadcast_server_main(port, hostname, tlsOptions);

View File

@ -42,13 +42,12 @@ namespace ix
server.disablePong();
}
server.setOnConnectionCallback(
[greetings](std::shared_ptr<ix::WebSocket> webSocket,
std::shared_ptr<ConnectionState> connectionState,
std::unique_ptr<ConnectionInfo> connectionInfo) {
auto remoteIp = connectionInfo->remoteIp;
webSocket->setOnMessageCallback(
[webSocket, connectionState, remoteIp, greetings](const WebSocketMessagePtr& msg) {
server.setOnClientMessageCallback(
[greetings](std::shared_ptr<ConnectionState> connectionState,
ConnectionInfo& connectionInfo,
WebSocket& webSocket,
const WebSocketMessagePtr& msg) {
auto remoteIp = connectionInfo.remoteIp;
if (msg->type == ix::WebSocketMessageType::Open)
{
spdlog::info("New connection");
@ -63,7 +62,7 @@ namespace ix
if (greetings)
{
webSocket->sendText("Welcome !");
webSocket.sendText("Welcome !");
}
}
else if (msg->type == ix::WebSocketMessageType::Close)
@ -83,10 +82,9 @@ namespace ix
else if (msg->type == ix::WebSocketMessageType::Message)
{
spdlog::info("Received {} bytes", msg->wireSize);
webSocket->send(msg->str, msg->binary);
webSocket.send(msg->str, msg->binary);
}
});
});
auto res = server.listen();
if (!res.first)