diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d4c63be..a82c9c05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ set( IXWEBSOCKET_SOURCES ixwebsocket/IXCancellationRequest.cpp ixwebsocket/IXNetSystem.cpp ixwebsocket/IXWebSocket.cpp + ixwebsocket/IXWebSocketMessageQueue.cpp ixwebsocket/IXWebSocketServer.cpp ixwebsocket/IXWebSocketTransport.cpp ixwebsocket/IXWebSocketHandshake.cpp @@ -55,6 +56,7 @@ set( IXWEBSOCKET_HEADERS ixwebsocket/IXNetSystem.h ixwebsocket/IXProgressCallback.h ixwebsocket/IXWebSocket.h + ixwebsocket/IXWebSocketMessageQueue.h ixwebsocket/IXWebSocketServer.h ixwebsocket/IXWebSocketTransport.h ixwebsocket/IXWebSocketHandshake.h diff --git a/ixwebsocket/IXWebSocket.h b/ixwebsocket/IXWebSocket.h index 331923ca..f25a88ec 100644 --- a/ixwebsocket/IXWebSocket.h +++ b/ixwebsocket/IXWebSocket.h @@ -119,7 +119,11 @@ namespace ix void close(uint16_t code = 1000, const std::string& reason = "Normal closure"); + // Set callback to receive websocket messages. + // Be aware: your callback will be executed from websocket's internal thread! + // To receive message events in your thread, look at WebSocketMessageQueue class void setOnMessageCallback(const OnMessageCallback& callback); + static void setTrafficTrackerCallback(const OnTrafficTrackerCallback& callback); static void resetTrafficTrackerCallback(); diff --git a/ixwebsocket/IXWebSocketMessageQueue.cpp b/ixwebsocket/IXWebSocketMessageQueue.cpp new file mode 100644 index 00000000..64e9ddb6 --- /dev/null +++ b/ixwebsocket/IXWebSocketMessageQueue.cpp @@ -0,0 +1,121 @@ +/* + * IXWebSocketMessageQueue.cpp + * Author: Korchynskyi Dmytro + * Copyright (c) 2017-2019 Machine Zone, Inc. All rights reserved. + */ + +#include "IXWebSocketMessageQueue.h" + +namespace ix +{ + + WebSocketMessageQueue::WebSocketMessageQueue(WebSocket* websocket) + { + bindWebsocket(websocket); + } + + WebSocketMessageQueue::~WebSocketMessageQueue() + { + if (!_messages.empty()) + { + // not handled all messages + } + + bindWebsocket(nullptr); + } + + void WebSocketMessageQueue::bindWebsocket(WebSocket * websocket) + { + if (_websocket == websocket) return; + + // unbind old + if (_websocket) + { + // set dummy callback just to avoid crash + _websocket->setOnMessageCallback([]( + WebSocketMessageType, + const std::string&, + size_t, + const WebSocketErrorInfo&, + const WebSocketOpenInfo&, + const WebSocketCloseInfo&) + {}); + } + + _websocket = websocket; + + // bind new + if (_websocket) + { + _websocket->setOnMessageCallback([this]( + WebSocketMessageType type, + const std::string& str, + size_t wireSize, + const WebSocketErrorInfo& errorInfo, + const WebSocketOpenInfo& openInfo, + const WebSocketCloseInfo& closeInfo) + { + MessagePtr message(new Message()); + + message->type = type; + message->str = str; + message->wireSize = wireSize; + message->errorInfo = errorInfo; + message->openInfo = openInfo; + message->closeInfo = closeInfo; + + { + std::lock_guard lock(_messagesMutex); + _messages.emplace_back(std::move(message)); + } + }); + } + } + + void WebSocketMessageQueue::setOnMessageCallback(const OnMessageCallback& callback) + { + _onMessageUserCallback = callback; + } + + void WebSocketMessageQueue::setOnMessageCallback(OnMessageCallback&& callback) + { + _onMessageUserCallback = std::move(callback); + } + + WebSocketMessageQueue::MessagePtr WebSocketMessageQueue::popMessage() + { + MessagePtr message; + std::lock_guard lock(_messagesMutex); + + if (!_messages.empty()) + { + message = std::move(_messages.front()); + _messages.pop_front(); + } + + return message; + } + + void WebSocketMessageQueue::poll(int count) + { + if (!_onMessageUserCallback) + return; + + MessagePtr message; + + while (count > 0 && (message = popMessage())) + { + _onMessageUserCallback( + message->type, + message->str, + message->wireSize, + message->errorInfo, + message->openInfo, + message->closeInfo + ); + + --count; + } + } + +} diff --git a/ixwebsocket/IXWebSocketMessageQueue.h b/ixwebsocket/IXWebSocketMessageQueue.h new file mode 100644 index 00000000..b8b85c25 --- /dev/null +++ b/ixwebsocket/IXWebSocketMessageQueue.h @@ -0,0 +1,53 @@ +/* + * IXWebSocketMessageQueue.h + * Author: Korchynskyi Dmytro + * Copyright (c) 2017-2019 Machine Zone, Inc. All rights reserved. + */ + +#pragma once + +#include "IXWebSocket.h" +#include +#include +#include + +namespace ix +{ + // + // A helper class to dispatch websocket message callbacks in your thread. + // + class WebSocketMessageQueue + { + public: + WebSocketMessageQueue(WebSocket* websocket = nullptr); + ~WebSocketMessageQueue(); + + void bindWebsocket(WebSocket* websocket); + + void setOnMessageCallback(const OnMessageCallback& callback); + void setOnMessageCallback(OnMessageCallback&& callback); + + void poll(int count = 512); + + protected: + struct Message + { + WebSocketMessageType type; + std::string str; + size_t wireSize; + WebSocketErrorInfo errorInfo; + WebSocketOpenInfo openInfo; + WebSocketCloseInfo closeInfo; + }; + + using MessagePtr = std::shared_ptr; + + MessagePtr popMessage(); + + private: + WebSocket* _websocket = nullptr; + OnMessageCallback _onMessageUserCallback; + std::mutex _messagesMutex; + std::list _messages; + }; +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 77ae3630..d27a796f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,7 @@ set (SOURCES IXWebSocketServerTest.cpp IXWebSocketTestConnectionDisconnection.cpp IXUrlParserTest.cpp + IXWebSocketMessageQTest.cpp IXWebSocketServerTest.cpp ) diff --git a/test/IXWebSocketMessageQTest.cpp b/test/IXWebSocketMessageQTest.cpp new file mode 100644 index 00000000..034a2ac0 --- /dev/null +++ b/test/IXWebSocketMessageQTest.cpp @@ -0,0 +1,191 @@ +/* + * IXWebSocketMessageQTest.cpp + * Author: Korchynskyi Dmytro + * Copyright (c) 2019 Machine Zone. All rights reserved. + */ + +#include +#include +#include + +#include "IXTest.h" +#include "catch.hpp" +#include + +using namespace ix; + +namespace +{ + bool startServer(ix::WebSocketServer& server) + { + server.setOnConnectionCallback( + [&server](std::shared_ptr webSocket, + std::shared_ptr connectionState) + { + webSocket->setOnMessageCallback( + [connectionState, &server](ix::WebSocketMessageType messageType, + const std::string & str, + size_t wireSize, + const ix::WebSocketErrorInfo & error, + const ix::WebSocketOpenInfo & openInfo, + const ix::WebSocketCloseInfo & closeInfo) + { + if (messageType == ix::WebSocketMessageType::Open) + { + Logger() << "New connection"; + connectionState->computeId(); + Logger() << "id: " << connectionState->getId(); + Logger() << "Uri: " << openInfo.uri; + Logger() << "Headers:"; + for (auto it : openInfo.headers) + { + Logger() << it.first << ": " << it.second; + } + } + else if (messageType == ix::WebSocketMessageType::Close) + { + Logger() << "Closed connection"; + } + else if (messageType == ix::WebSocketMessageType::Message) + { + Logger() << "Message received: " << str; + + for (auto&& client : server.getClients()) + { + client->send(str); + } + } + } + ); + } + ); + + auto res = server.listen(); + if (!res.first) + { + Logger() << res.second; + return false; + } + + server.start(); + return true; + } + + class MsgQTestClient + { + public: + MsgQTestClient() + { + msgQ.bindWebsocket(&ws); + + msgQ.setOnMessageCallback([this](WebSocketMessageType messageType, + const std::string & str, + size_t wireSize, + const WebSocketErrorInfo & error, + const WebSocketOpenInfo & openInfo, + const WebSocketCloseInfo & closeInfo) + { + REQUIRE(mainThreadId == std::this_thread::get_id()); + + std::stringstream ss; + if (messageType == WebSocketMessageType::Open) + { + log("client connected"); + sendNextMessage(); + } + else if (messageType == WebSocketMessageType::Close) + { + log("client disconnected"); + } + else if (messageType == WebSocketMessageType::Error) + { + ss << "Error ! " << error.reason; + log(ss.str()); + testDone = true; + } + else if (messageType == WebSocketMessageType::Pong) + { + ss << "Received pong message " << str; + log(ss.str()); + } + else if (messageType == WebSocketMessageType::Ping) + { + ss << "Received ping message " << str; + log(ss.str()); + } + else if (messageType == WebSocketMessageType::Message) + { + REQUIRE(str.compare("Hey dude!") == 0); + ++receivedCount; + ss << "Received message " << str; + log(ss.str()); + sendNextMessage(); + } + else + { + ss << "Invalid WebSocketMessageType"; + log(ss.str()); + testDone = true; + } + }); + } + + void sendNextMessage() + { + if (receivedCount >= 3) + { + testDone = true; + succeeded = true; + } + else + { + auto info = ws.sendText("Hey dude!"); + if (info.success) + log("sent message"); + else + log("send failed"); + } + } + + void run(const std::string& url) + { + mainThreadId = std::this_thread::get_id(); + testDone = false; + receivedCount = 0; + + ws.setUrl(url); + ws.start(); + + while (!testDone) + { + msgQ.poll(); + msleep(50); + } + } + + bool isSucceeded() const { return succeeded; } + + private: + WebSocket ws; + WebSocketMessageQueue msgQ; + bool testDone = false; + uint32_t receivedCount = 0; + std::thread::id mainThreadId; + bool succeeded = false; + }; +} + +TEST_CASE("Websocket_message_queue", "[websocket_message_q]") +{ + SECTION("Send several messages") + { + int port = getFreePort(); + WebSocketServer server(port); + REQUIRE(startServer(server)); + + MsgQTestClient testClient; + testClient.run("ws://127.0.0.1:" + std::to_string(port)); + REQUIRE(testClient.isSucceeded()); + } + +}