Merge pull request #6 from machinezone/user/bsergeant/server

Add support for writing websocket servers (IXWebSocketServer)
This commit is contained in:
Benjamin Sergeant 2019-01-03 18:47:30 -08:00 committed by GitHub
commit c236ff66e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1893 additions and 473 deletions

View File

@ -10,13 +10,18 @@ set (CMAKE_CXX_STANDARD 11)
set (CXX_STANDARD_REQUIRED ON) set (CXX_STANDARD_REQUIRED ON)
set (CMAKE_CXX_EXTENSIONS OFF) set (CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Wshorten-64-to-32")
set( IXWEBSOCKET_SOURCES set( IXWEBSOCKET_SOURCES
ixwebsocket/IXEventFd.cpp ixwebsocket/IXEventFd.cpp
ixwebsocket/IXSocket.cpp ixwebsocket/IXSocket.cpp
ixwebsocket/IXSocketConnect.cpp ixwebsocket/IXSocketConnect.cpp
ixwebsocket/IXDNSLookup.cpp ixwebsocket/IXDNSLookup.cpp
ixwebsocket/IXCancellationRequest.cpp
ixwebsocket/IXWebSocket.cpp ixwebsocket/IXWebSocket.cpp
ixwebsocket/IXWebSocketServer.cpp
ixwebsocket/IXWebSocketTransport.cpp ixwebsocket/IXWebSocketTransport.cpp
ixwebsocket/IXWebSocketHandshake.cpp
ixwebsocket/IXWebSocketPerMessageDeflate.cpp ixwebsocket/IXWebSocketPerMessageDeflate.cpp
ixwebsocket/IXWebSocketPerMessageDeflateOptions.cpp ixwebsocket/IXWebSocketPerMessageDeflateOptions.cpp
) )
@ -29,7 +34,9 @@ set( IXWEBSOCKET_HEADERS
ixwebsocket/IXDNSLookup.h ixwebsocket/IXDNSLookup.h
ixwebsocket/IXCancellationRequest.h ixwebsocket/IXCancellationRequest.h
ixwebsocket/IXWebSocket.h ixwebsocket/IXWebSocket.h
ixwebsocket/IXWebSocketServer.h
ixwebsocket/IXWebSocketTransport.h ixwebsocket/IXWebSocketTransport.h
ixwebsocket/IXWebSocketHandshake.h
ixwebsocket/IXWebSocketSendInfo.h ixwebsocket/IXWebSocketSendInfo.h
ixwebsocket/IXWebSocketErrorInfo.h ixwebsocket/IXWebSocketErrorInfo.h
ixwebsocket/IXWebSocketPerMessageDeflate.h ixwebsocket/IXWebSocketPerMessageDeflate.h

View File

@ -3,7 +3,7 @@
## Introduction ## Introduction
[*WebSocket*](https://en.wikipedia.org/wiki/WebSocket) is a computer communications protocol, providing full-duplex [*WebSocket*](https://en.wikipedia.org/wiki/WebSocket) is a computer communications protocol, providing full-duplex
communication channels over a single TCP connection. *IXWebSocket* is a C++ library for Websocket communication. The code is derived from [easywsclient](https://github.com/dhbaird/easywsclient) and from the [Satori C SDK](https://github.com/satori-com/satori-rtm-sdk-c). It has been tested on the following platforms. communication channels over a single TCP connection. *IXWebSocket* is a C++ library for client and server Websocket communication. The code is derived from [easywsclient](https://github.com/dhbaird/easywsclient) and from the [Satori C SDK](https://github.com/satori-com/satori-rtm-sdk-c). It has been tested on the following platforms.
* macOS * macOS
* iOS * iOS
@ -15,7 +15,7 @@ communication channels over a single TCP connection. *IXWebSocket* is a C++ libr
The examples folder countains a simple chat program, using a node.js broadcast server. The examples folder countains a simple chat program, using a node.js broadcast server.
Here is what the API looks like. Here is what the client API looks like.
``` ```
ix::WebSocket webSocket; ix::WebSocket webSocket;
@ -50,10 +50,66 @@ webSocket.send("hello world");
webSocket.stop() webSocket.stop()
``` ```
Here is what the server API looks like. Note that server support is very recent and subject to changes.
```
// Run a server on localhost at a given port.
// Bound host name, max connections and listen backlog can also be passed in as parameters.
ix::WebSocketServer server(port);
server.setOnConnectionCallback(
[&server](std::shared_ptr<ix::WebSocket> webSocket)
{
webSocket->setOnMessageCallback(
[webSocket, &server](ix::WebSocketMessageType messageType,
const std::string& str,
size_t wireSize,
const ix::WebSocketErrorInfo& error,
const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketCloseInfo& closeInfo)
{
if (messageType == ix::WebSocket_MessageType_Open)
{
std::cerr << "New connection" << std::endl;
std::cerr << "Uri: " << openInfo.uri << std::endl;
std::cerr << "Headers:" << std::endl;
for (auto it : openInfo.headers)
{
std::cerr << it.first << ": " << it.second << std::endl;
}
}
else if (messageType == ix::WebSocket_MessageType_Message)
{
// For an echo server, we just send back to the client whatever was received by the client
// All connected clients are available in an std::set. See the broadcast cpp example.
webSocket->send(str);
}
}
);
}
);
auto res = server.listen();
if (!res.first)
{
// Error handling
return 1;
}
// Run the server in the background. Server can be stoped by calling server.stop()
server.start();
// Block until server.stop() is called.
server.wait();
```
## Build ## Build
CMakefiles for the library and the examples are available. This library has few dependencies, so it is possible to just add the source files into your project. CMakefiles for the library and the examples are available. This library has few dependencies, so it is possible to just add the source files into your project.
There is a Dockerfile for running some code on Linux, and a unittest which can be executed by typing `make test`.
## Implementation details ## Implementation details
### Per Message Deflate compression. ### Per Message Deflate compression.
@ -76,6 +132,7 @@ If the remote end (server) breaks the connection, the code will try to perpetual
* There is no text support for sending data, only the binary protocol is supported. Sending json or text over the binary protocol works well. * There is no text support for sending data, only the binary protocol is supported. Sending json or text over the binary protocol works well.
* Automatic reconnection works at the TCP socket level, and will detect remote end disconnects. However, if the device/computer network become unreachable (by turning off wifi), it is quite hard to reliably and timely detect it at the socket level using `recv` and `send` error codes. [Here](https://stackoverflow.com/questions/14782143/linux-socket-how-to-detect-disconnected-network-in-a-client-program) is a good discussion on the subject. This behavior is consistent with other runtimes such as node.js. One way to detect a disconnected device with low level C code is to do a name resolution with DNS but this can be expensive. Mobile devices have good and reliable API to do that. * Automatic reconnection works at the TCP socket level, and will detect remote end disconnects. However, if the device/computer network become unreachable (by turning off wifi), it is quite hard to reliably and timely detect it at the socket level using `recv` and `send` error codes. [Here](https://stackoverflow.com/questions/14782143/linux-socket-how-to-detect-disconnected-network-in-a-client-program) is a good discussion on the subject. This behavior is consistent with other runtimes such as node.js. One way to detect a disconnected device with low level C code is to do a name resolution with DNS but this can be expensive. Mobile devices have good and reliable API to do that.
* The server code is using select to detect incoming data, and creates one OS thread per connection. This isn't as scalable as strategies using epoll or kqueue.
## Examples ## Examples
@ -92,21 +149,38 @@ If the remote end (server) breaks the connection, the code will try to perpetual
Here's a simplistic diagram which explains how the code is structured in term of class/modules. Here's a simplistic diagram which explains how the code is structured in term of class/modules.
``` ```
+-----------------------+ +-----------------------+ --- Public
| | Start the receiving Background thread. Auto reconnection. Simple websocket Ping. | | Start the receiving Background thread. Auto reconnection. Simple websocket Ping.
| IXWebSocket | Interface used by C++ test clients. No IX dependencies. | IXWebSocket | Interface used by C++ test clients. No IX dependencies.
| | | |
+-----------------------+ +-----------------------+
| | | |
| IXWebSocketServer | Run a server and give each connections its own WebSocket object.
| | Each connection is handled in a new OS thread.
| |
+-----------------------+ --- Private
| |
| IXWebSocketTransport | Low level websocket code, framing, managing raw socket. Adapted from easywsclient. | IXWebSocketTransport | Low level websocket code, framing, managing raw socket. Adapted from easywsclient.
| | | |
+-----------------------+ +-----------------------+
| | | |
| IXWebSocketHandshake | Establish the connection between client and server.
| |
+-----------------------+
| |
| IXWebSocket | ws:// Unencrypted Socket handler | IXWebSocket | ws:// Unencrypted Socket handler
| IXWebSocketAppleSSL | wss:// TLS encrypted Socket AppleSSL handler. Used on iOS and macOS | IXWebSocketAppleSSL | wss:// TLS encrypted Socket AppleSSL handler. Used on iOS and macOS
| IXWebSocketOpenSSL | wss:// TLS encrypted Socket OpenSSL handler. Used on Android and Linux | IXWebSocketOpenSSL | wss:// TLS encrypted Socket OpenSSL handler. Used on Android and Linux
| | Can be used on macOS too. | | Can be used on macOS too.
+-----------------------+ +-----------------------+
| |
| IXSocketConnect | Connect to the remote host (client).
| |
+-----------------------+
| |
| IXDNSLookup | Does DNS resolution asynchronously so that it can be interrupted.
| |
+-----------------------+
``` ```
## API ## API

View File

@ -10,8 +10,10 @@ RUN apt-get -y install procps
RUN apt-get -y install lsof RUN apt-get -y install lsof
RUN apt-get -y install libz-dev RUN apt-get -y install libz-dev
RUN apt-get -y install vim RUN apt-get -y install vim
RUN apt-get -y install make
RUN apt-get -y install cmake
COPY . . COPY . .
WORKDIR examples/ws_connect WORKDIR test
RUN ["sh", "build_linux.sh"] RUN ["sh", "build_linux.sh"]

9
examples/broadcast_server/.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
CMakeCache.txt
package-lock.json
CMakeFiles
ixwebsocket_unittest
cmake_install.cmake
node_modules
ixwebsocket
Makefile
build

View File

@ -0,0 +1,30 @@
#
# Author: Benjamin Sergeant
# Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
#
cmake_minimum_required (VERSION 3.4.1)
project (broadcast_server)
# There's -Weverything too for clang
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Wshorten-64-to-32")
set (OPENSSL_PREFIX /usr/local/opt/openssl) # Homebrew openssl
set (CMAKE_CXX_STANDARD 11)
option(USE_TLS "Add TLS support" ON)
add_subdirectory(${PROJECT_SOURCE_DIR}/../.. ixwebsocket)
include_directories(broadcast_server .)
add_executable(broadcast_server
broadcast_server.cpp)
if (APPLE AND USE_TLS)
target_link_libraries(broadcast_server "-framework foundation" "-framework security")
endif()
target_link_libraries(broadcast_server ixwebsocket)
install(TARGETS broadcast_server DESTINATION bin)

View File

@ -0,0 +1,74 @@
/*
* broadcast_server.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#include <iostream>
#include <sstream>
#include <ixwebsocket/IXWebSocketServer.h>
int main(int argc, char** argv)
{
int port = 8080;
if (argc == 2)
{
std::stringstream ss;
ss << argv[1];
ss >> port;
}
ix::WebSocketServer server(port);
server.setOnConnectionCallback(
[&server](std::shared_ptr<ix::WebSocket> webSocket)
{
webSocket->setOnMessageCallback(
[webSocket, &server](ix::WebSocketMessageType messageType,
const std::string& str,
size_t wireSize,
const ix::WebSocketErrorInfo& error,
const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketCloseInfo& closeInfo)
{
if (messageType == ix::WebSocket_MessageType_Open)
{
std::cerr << "New connection" << std::endl;
std::cerr << "Uri: " << openInfo.uri << std::endl;
std::cerr << "Headers:" << std::endl;
for (auto it : openInfo.headers)
{
std::cerr << it.first << ": " << it.second << std::endl;
}
}
else if (messageType == ix::WebSocket_MessageType_Close)
{
std::cerr << "Closed connection" << std::endl;
}
else if (messageType == ix::WebSocket_MessageType_Message)
{
for (auto&& client : server.getClients())
{
if (client != webSocket)
{
client->send(str);
}
}
}
}
);
}
);
auto res = server.listen();
if (!res.first)
{
std::cerr << res.second << std::endl;
return 1;
}
server.start();
server.wait();
return 0;
}

View File

@ -87,8 +87,8 @@ namespace
const std::string& str, const std::string& str,
size_t wireSize, size_t wireSize,
const ix::WebSocketErrorInfo& error, const ix::WebSocketErrorInfo& error,
const ix::WebSocketCloseInfo& closeInfo, const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketHttpHeaders& headers) const ix::WebSocketCloseInfo& closeInfo)
{ {
std::stringstream ss; std::stringstream ss;
if (messageType == ix::WebSocket_MessageType_Open) if (messageType == ix::WebSocket_MessageType_Open)

View File

@ -0,0 +1,30 @@
#
# Author: Benjamin Sergeant
# Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
#
cmake_minimum_required (VERSION 3.4.1)
project (echo_server)
# There's -Weverything too for clang
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Wshorten-64-to-32")
set (OPENSSL_PREFIX /usr/local/opt/openssl) # Homebrew openssl
set (CMAKE_CXX_STANDARD 11)
option(USE_TLS "Add TLS support" ON)
add_subdirectory(${PROJECT_SOURCE_DIR}/../.. ixwebsocket)
include_directories(echo_server .)
add_executable(echo_server
echo_server.cpp)
if (APPLE AND USE_TLS)
target_link_libraries(echo_server "-framework foundation" "-framework security")
endif()
target_link_libraries(echo_server ixwebsocket)
install(TARGETS echo_server DESTINATION bin)

View File

@ -0,0 +1,68 @@
/*
* echo_server.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#include <iostream>
#include <sstream>
#include <ixwebsocket/IXWebSocketServer.h>
int main(int argc, char** argv)
{
int port = 8080;
if (argc == 2)
{
std::stringstream ss;
ss << argv[1];
ss >> port;
}
ix::WebSocketServer server(port);
server.setOnConnectionCallback(
[&server](std::shared_ptr<ix::WebSocket> webSocket)
{
webSocket->setOnMessageCallback(
[webSocket, &server](ix::WebSocketMessageType messageType,
const std::string& str,
size_t wireSize,
const ix::WebSocketErrorInfo& error,
const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketCloseInfo& closeInfo)
{
if (messageType == ix::WebSocket_MessageType_Open)
{
std::cerr << "New connection" << std::endl;
std::cerr << "Uri: " << openInfo.uri << std::endl;
std::cerr << "Headers:" << std::endl;
for (auto it : openInfo.headers)
{
std::cerr << it.first << ": " << it.second << std::endl;
}
}
else if (messageType == ix::WebSocket_MessageType_Close)
{
std::cerr << "Closed connection" << std::endl;
}
else if (messageType == ix::WebSocket_MessageType_Message)
{
webSocket->send(str);
}
}
);
}
);
auto res = server.listen();
if (!res.first)
{
std::cerr << res.second << std::endl;
return 1;
}
server.start();
server.wait();
return 0;
}

View File

@ -58,8 +58,8 @@ namespace
const std::string& str, const std::string& str,
size_t wireSize, size_t wireSize,
const ix::WebSocketErrorInfo& error, const ix::WebSocketErrorInfo& error,
const ix::WebSocketCloseInfo& closeInfo, const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketHttpHeaders& headers) const ix::WebSocketCloseInfo& closeInfo)
{ {
std::stringstream ss; std::stringstream ss;
if (messageType == ix::WebSocket_MessageType_Open) if (messageType == ix::WebSocket_MessageType_Open)

View File

@ -61,15 +61,16 @@ namespace
const std::string& str, const std::string& str,
size_t wireSize, size_t wireSize,
const ix::WebSocketErrorInfo& error, const ix::WebSocketErrorInfo& error,
const ix::WebSocketCloseInfo& closeInfo, const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketHttpHeaders& headers) const ix::WebSocketCloseInfo& closeInfo)
{ {
std::stringstream ss; std::stringstream ss;
if (messageType == ix::WebSocket_MessageType_Open) if (messageType == ix::WebSocket_MessageType_Open)
{ {
log("ws_connect: connected"); log("ws_connect: connected");
std::cout << "Uri: " << openInfo.uri << std::endl;
std::cout << "Handshake Headers:" << std::endl; std::cout << "Handshake Headers:" << std::endl;
for (auto it : headers) for (auto it : openInfo.headers)
{ {
std::cout << it.first << ": " << it.second << std::endl; std::cout << it.first << ": " << it.second << std::endl;
} }

View File

@ -0,0 +1,33 @@
/*
* IXCancellationRequest.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
*/
#include "IXCancellationRequest.h"
#include <chrono>
namespace ix
{
CancellationRequest makeCancellationRequestWithTimeout(int secs,
std::atomic<bool>& requestInitCancellation)
{
auto start = std::chrono::system_clock::now();
auto timeout = std::chrono::seconds(secs);
auto isCancellationRequested = [&requestInitCancellation, start, timeout]() -> bool
{
// Was an explicit cancellation requested ?
if (requestInitCancellation) return true;
auto now = std::chrono::system_clock::now();
if ((now - start) > timeout) return true;
// No cancellation request
return false;
};
return isCancellationRequested;
}
}

View File

@ -7,9 +7,13 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <atomic>
namespace ix namespace ix
{ {
using CancellationRequest = std::function<bool()>; using CancellationRequest = std::function<bool()>;
CancellationRequest makeCancellationRequestWithTimeout(int seconds,
std::atomic<bool>& requestInitCancellation);
} }

View File

@ -14,9 +14,7 @@
namespace ix namespace ix
{ {
// 60s timeout, see IXSocketConnect.cpp const int64_t DNSLookup::kDefaultWait = 10; // ms
const int64_t DNSLookup::kDefaultTimeout = 60 * 1000; // ms
const int64_t DNSLookup::kDefaultWait = 10; // ms
std::atomic<uint64_t> DNSLookup::_nextId(0); std::atomic<uint64_t> DNSLookup::_nextId(0);
std::set<uint64_t> DNSLookup::_activeJobs; std::set<uint64_t> DNSLookup::_activeJobs;
@ -112,7 +110,6 @@ namespace ix
_thread = std::thread(&DNSLookup::run, this); _thread = std::thread(&DNSLookup::run, this);
_thread.detach(); _thread.detach();
int64_t timeout = kDefaultTimeout;
std::unique_lock<std::mutex> lock(_conditionVariableMutex); std::unique_lock<std::mutex> lock(_conditionVariableMutex);
while (!_done) while (!_done)
@ -131,14 +128,6 @@ namespace ix
errMsg = "cancellation requested"; errMsg = "cancellation requested";
return nullptr; return nullptr;
} }
// Have we exceeded the timeout ?
timeout -= _wait;
if (timeout <= 0)
{
errMsg = "dns lookup timed out after 60 seconds";
return nullptr;
}
} }
// Maybe a cancellation request got in before the bg terminated ? // Maybe a cancellation request got in before the bg terminated ?

View File

@ -61,7 +61,6 @@ namespace ix
static std::set<uint64_t> _activeJobs; static std::set<uint64_t> _activeJobs;
static std::mutex _activeJobsMutex; static std::mutex _activeJobsMutex;
const static int64_t kDefaultTimeout;
const static int64_t kDefaultWait; const static int64_t kDefaultWait;
}; };
} }

View File

@ -37,8 +37,8 @@
namespace ix namespace ix
{ {
Socket::Socket() : Socket::Socket(int fd) :
_sockfd(-1) _sockfd(fd)
{ {
} }
@ -162,4 +162,94 @@ namespace ix
WSACleanup(); WSACleanup();
#endif #endif
} }
bool Socket::readByte(void* buffer,
const CancellationRequest& isCancellationRequested)
{
while (true)
{
if (isCancellationRequested()) return false;
int ret;
ret = recv(buffer, 1);
// We read one byte, as needed, all good.
if (ret == 1)
{
return true;
}
// There is possibly something to be read, try again
else if (ret < 0 && (getErrno() == EWOULDBLOCK ||
getErrno() == EAGAIN))
{
// Wait with a timeout until something is written.
// This way we are not busy looping
fd_set rfds;
struct timeval timeout;
timeout.tv_sec = 0;
timeout.tv_usec = 1 * 1000; // 1ms
FD_ZERO(&rfds);
FD_SET(_sockfd, &rfds);
select(_sockfd + 1, &rfds, nullptr, nullptr, &timeout);
continue;
}
// There was an error during the read, abort
else
{
return false;
}
}
}
bool Socket::writeBytes(const std::string& str,
const CancellationRequest& isCancellationRequested)
{
while (true)
{
if (isCancellationRequested()) return false;
char* buffer = const_cast<char*>(str.c_str());
int len = (int) str.size();
int ret = send(buffer, len);
// We wrote some bytes, as needed, all good.
if (ret > 0)
{
return ret == len;
}
// There is possibly something to be write, try again
else if (ret < 0 && (getErrno() == EWOULDBLOCK ||
getErrno() == EAGAIN))
{
continue;
}
// There was an error during the write, abort
else
{
return false;
}
}
}
std::pair<bool, std::string> Socket::readLine(const CancellationRequest& isCancellationRequested)
{
char c;
std::string line;
line.reserve(64);
for (int i = 0; i < 2 || (line[i-2] != '\r' && line[i-1] != '\n'); ++i)
{
if (!readByte(&c, isCancellationRequested))
{
return std::make_pair(false, std::string());
}
line += c;
}
return std::make_pair(true, line);
}
} }

View File

@ -14,15 +14,13 @@
#include "IXEventFd.h" #include "IXEventFd.h"
#include "IXCancellationRequest.h" #include "IXCancellationRequest.h"
struct addrinfo;
namespace ix namespace ix
{ {
class Socket { class Socket {
public: public:
using OnPollCallback = std::function<void()>; using OnPollCallback = std::function<void()>;
Socket(); Socket(int fd = -1);
virtual ~Socket(); virtual ~Socket();
void configure(); void configure();
@ -41,6 +39,14 @@ namespace ix
virtual int send(const std::string& buffer); virtual int send(const std::string& buffer);
virtual int recv(void* buffer, size_t length); virtual int recv(void* buffer, size_t length);
// Blocking and cancellable versions, working with socket that can be set
// to non blocking mode. Used during HTTP upgrade.
bool readByte(void* buffer,
const CancellationRequest& isCancellationRequested);
bool writeBytes(const std::string& str,
const CancellationRequest& isCancellationRequested);
std::pair<bool, std::string> readLine(const CancellationRequest& isCancellationRequested);
int getErrno() const; int getErrno() const;
static bool init(); // Required on Windows to initialize WinSocket static bool init(); // Required on Windows to initialize WinSocket
static void cleanup(); // Required on Windows to cleanup WinSocket static void cleanup(); // Required on Windows to cleanup WinSocket
@ -51,8 +57,5 @@ namespace ix
std::atomic<int> _sockfd; std::atomic<int> _sockfd;
std::mutex _socketMutex; std::mutex _socketMutex;
EventFd _eventfd; EventFd _eventfd;
private:
}; };
} }

View File

@ -143,7 +143,7 @@ std::string getSSLErrorDescription(OSStatus status)
namespace ix namespace ix
{ {
SocketAppleSSL::SocketAppleSSL() : SocketAppleSSL::SocketAppleSSL(int fd) : Socket(fd),
_sslContext(nullptr) _sslContext(nullptr)
{ {
; ;

View File

@ -19,7 +19,7 @@ namespace ix
class SocketAppleSSL : public Socket class SocketAppleSSL : public Socket
{ {
public: public:
SocketAppleSSL(); SocketAppleSSL(int fd = -1);
~SocketAppleSSL(); ~SocketAppleSSL();
virtual bool connect(const std::string& host, virtual bool connect(const std::string& host,

View File

@ -53,12 +53,11 @@ namespace ix
// This is important so that we don't block the main UI thread when shutting down a connection which is // This is important so that we don't block the main UI thread when shutting down a connection which is
// already trying to reconnect, and can be blocked waiting for ::connect to respond. // already trying to reconnect, and can be blocked waiting for ::connect to respond.
// //
bool SocketConnect::connectToAddress(const struct addrinfo *address, int SocketConnect::connectToAddress(const struct addrinfo *address,
int& sockfd, std::string& errMsg,
std::string& errMsg, const CancellationRequest& isCancellationRequested)
const CancellationRequest& isCancellationRequested)
{ {
sockfd = -1; errMsg = "no error";
int fd = socket(address->ai_family, int fd = socket(address->ai_family,
address->ai_socktype, address->ai_socktype,
@ -66,7 +65,7 @@ namespace ix
if (fd < 0) if (fd < 0)
{ {
errMsg = "Cannot create a socket"; errMsg = "Cannot create a socket";
return false; return -1;
} }
// Set the socket to non blocking mode, so that slow responses cannot // Set the socket to non blocking mode, so that slow responses cannot
@ -78,24 +77,12 @@ namespace ix
{ {
closeSocket(fd); closeSocket(fd);
errMsg = strerror(errno); errMsg = strerror(errno);
return false; return -1;
} }
// for (;;)
// If during a connection attempt the request remains idle for longer
// than the timeout interval, the request is considered to have timed
// out. The default timeout interval is 60 seconds.
//
// See https://developer.apple.com/documentation/foundation/nsmutableurlrequest/1414063-timeoutinterval?language=objc
//
// 60 seconds timeout, each time we wait for 50ms with select -> 1200 attempts
//
int selectTimeOut = 50 * 1000; // In micro-seconds => 50ms
int maxRetries = 60 * 1000 * 1000 / selectTimeOut;
for (int i = 0; i < maxRetries; ++i)
{ {
if (isCancellationRequested()) if (isCancellationRequested()) // Must handle timeout as well
{ {
closeSocket(fd); closeSocket(fd);
errMsg = "Cancelled"; errMsg = "Cancelled";
@ -106,10 +93,10 @@ namespace ix
FD_ZERO(&wfds); FD_ZERO(&wfds);
FD_SET(fd, &wfds); FD_SET(fd, &wfds);
// 50ms timeout // 50ms select timeout
struct timeval timeout; struct timeval timeout;
timeout.tv_sec = 0; timeout.tv_sec = 0;
timeout.tv_usec = selectTimeOut; timeout.tv_usec = 50 * 1000;
select(fd + 1, nullptr, &wfds, nullptr, &timeout); select(fd + 1, nullptr, &wfds, nullptr, &timeout);
@ -127,19 +114,18 @@ namespace ix
{ {
closeSocket(fd); closeSocket(fd);
errMsg = strerror(optval); errMsg = strerror(optval);
return false; return -1;
} }
else else
{ {
// Success ! // Success !
sockfd = fd; return fd;
return true;
} }
} }
closeSocket(fd); closeSocket(fd);
errMsg = "connect timed out after 60 seconds"; errMsg = "connect timed out after 60 seconds";
return false; return -1;
} }
int SocketConnect::connect(const std::string& hostname, int SocketConnect::connect(const std::string& hostname,
@ -161,14 +147,13 @@ namespace ix
// iterate through the records to find a working peer // iterate through the records to find a working peer
struct addrinfo *address; struct addrinfo *address;
bool success = false;
for (address = res; address != nullptr; address = address->ai_next) for (address = res; address != nullptr; address = address->ai_next)
{ {
// //
// Second try to connect to the remote host // Second try to connect to the remote host
// //
success = connectToAddress(address, sockfd, errMsg, isCancellationRequested); sockfd = connectToAddress(address, errMsg, isCancellationRequested);
if (success) if (sockfd != -1)
{ {
break; break;
} }
@ -178,6 +163,7 @@ namespace ix
return sockfd; return sockfd;
} }
// FIXME: configure is a terrible name
void SocketConnect::configure(int sockfd) void SocketConnect::configure(int sockfd)
{ {
// 1. disable Nagle's algorithm // 1. disable Nagle's algorithm

View File

@ -21,13 +21,12 @@ namespace ix
std::string& errMsg, std::string& errMsg,
const CancellationRequest& isCancellationRequested); const CancellationRequest& isCancellationRequested);
private:
static bool connectToAddress(const struct addrinfo *address,
int& sockfd,
std::string& errMsg,
const CancellationRequest& isCancellationRequested);
static void configure(int sockfd); static void configure(int sockfd);
private:
static int connectToAddress(const struct addrinfo *address,
std::string& errMsg,
const CancellationRequest& isCancellationRequested);
}; };
} }

View File

@ -74,7 +74,7 @@ SSL *openssl_create_connection(SSL_CTX *ctx, int socket)
namespace ix namespace ix
{ {
SocketOpenSSL::SocketOpenSSL() : SocketOpenSSL::SocketOpenSSL(int fd) : Socket(fd),
_ssl_connection(nullptr), _ssl_connection(nullptr),
_ssl_context(nullptr) _ssl_context(nullptr)
{ {

View File

@ -22,7 +22,7 @@ namespace ix
class SocketOpenSSL : public Socket class SocketOpenSSL : public Socket
{ {
public: public:
SocketOpenSSL(); SocketOpenSSL(int fd = -1);
~SocketOpenSSL(); ~SocketOpenSSL();
virtual bool connect(const std::string& host, virtual bool connect(const std::string& host,

View File

@ -82,7 +82,6 @@ namespace ix
void SocketSChannel::secureSocket() void SocketSChannel::secureSocket()
{ {
// there will be a lot to do here ... // there will be a lot to do here ...
// FIXME do something with sockerror
} }
void SocketSChannel::close() void SocketSChannel::close()

View File

@ -6,6 +6,7 @@
#include "IXWebSocket.h" #include "IXWebSocket.h"
#include "IXSetThreadName.h" #include "IXSetThreadName.h"
#include "IXWebSocketHandshake.h"
#include <iostream> #include <iostream>
#include <cmath> #include <cmath>
@ -29,12 +30,22 @@ namespace
namespace ix namespace ix
{ {
OnTrafficTrackerCallback WebSocket::_onTrafficTrackerCallback = nullptr; OnTrafficTrackerCallback WebSocket::_onTrafficTrackerCallback = nullptr;
const int WebSocket::kDefaultHandShakeTimeoutSecs(60);
WebSocket::WebSocket() : WebSocket::WebSocket() :
_onMessageCallback(OnMessageCallback()), _onMessageCallback(OnMessageCallback()),
_stop(false), _stop(false),
_automaticReconnection(true) _automaticReconnection(true),
_handshakeTimeoutSecs(kDefaultHandShakeTimeoutSecs)
{ {
_ws.setOnCloseCallback(
[this](uint16_t code, const std::string& reason, size_t wireSize)
{
_onMessageCallback(WebSocket_MessageType_Close, "", wireSize,
WebSocketErrorInfo(), WebSocketOpenInfo(),
WebSocketCloseInfo(code, reason));
}
);
} }
WebSocket::~WebSocket() WebSocket::~WebSocket()
@ -75,13 +86,16 @@ namespace ix
void WebSocket::stop() void WebSocket::stop()
{ {
bool automaticReconnection = _automaticReconnection;
// This value needs to be forced when shutting down, it is restored later
_automaticReconnection = false; _automaticReconnection = false;
close(); close();
if (!_thread.joinable()) if (!_thread.joinable())
{ {
_automaticReconnection = true; _automaticReconnection = automaticReconnection;
return; return;
} }
@ -89,35 +103,46 @@ namespace ix
_thread.join(); _thread.join();
_stop = false; _stop = false;
_automaticReconnection = true; _automaticReconnection = automaticReconnection;
} }
WebSocketInitResult WebSocket::connect() WebSocketInitResult WebSocket::connect(int timeoutSecs)
{ {
{ {
std::lock_guard<std::mutex> lock(_configMutex); std::lock_guard<std::mutex> lock(_configMutex);
_ws.configure(_url, _perMessageDeflateOptions); _ws.configure(_perMessageDeflateOptions);
} }
_ws.setOnCloseCallback( WebSocketInitResult status = _ws.connectToUrl(_url, timeoutSecs);
[this](uint16_t code, const std::string& reason, size_t wireSize)
{
_onMessageCallback(WebSocket_MessageType_Close, "", wireSize,
WebSocketErrorInfo(),
WebSocketCloseInfo(code, reason),
WebSocketHttpHeaders());
}
);
WebSocketInitResult status = _ws.init();
if (!status.success) if (!status.success)
{ {
return status; return status;
} }
_onMessageCallback(WebSocket_MessageType_Open, "", 0, _onMessageCallback(WebSocket_MessageType_Open, "", 0,
WebSocketErrorInfo(), WebSocketCloseInfo(), WebSocketErrorInfo(),
status.headers); WebSocketOpenInfo(status.uri, status.headers),
WebSocketCloseInfo());
return status;
}
WebSocketInitResult WebSocket::connectToSocket(int fd, int timeoutSecs)
{
{
std::lock_guard<std::mutex> lock(_configMutex);
_ws.configure(_perMessageDeflateOptions);
}
WebSocketInitResult status = _ws.connectToSocket(fd, timeoutSecs);
if (!status.success)
{
return status;
}
_onMessageCallback(WebSocket_MessageType_Open, "", 0,
WebSocketErrorInfo(),
WebSocketOpenInfo(status.uri, status.headers),
WebSocketCloseInfo());
return status; return status;
} }
@ -151,7 +176,7 @@ namespace ix
break; break;
} }
status = connect(); status = connect(_handshakeTimeoutSecs);
if (!status.success && !_stop) if (!status.success && !_stop)
{ {
@ -162,8 +187,8 @@ namespace ix
connectErr.reason = status.errorStr; connectErr.reason = status.errorStr;
connectErr.http_status = status.http_status; connectErr.http_status = status.http_status;
_onMessageCallback(WebSocket_MessageType_Error, "", 0, _onMessageCallback(WebSocket_MessageType_Error, "", 0,
connectErr, WebSocketCloseInfo(), connectErr, WebSocketOpenInfo(),
WebSocketHttpHeaders()); WebSocketCloseInfo());
std::this_thread::sleep_for(duration); std::this_thread::sleep_for(duration);
} }
@ -218,11 +243,16 @@ namespace ix
webSocketErrorInfo.decompressionError = decompressionError; webSocketErrorInfo.decompressionError = decompressionError;
_onMessageCallback(webSocketMessageType, msg, wireSize, _onMessageCallback(webSocketMessageType, msg, wireSize,
webSocketErrorInfo, WebSocketCloseInfo(), webSocketErrorInfo, WebSocketOpenInfo(),
WebSocketHttpHeaders()); WebSocketCloseInfo());
WebSocket::invokeTrafficTrackerCallback(msg.size(), true); WebSocket::invokeTrafficTrackerCallback(msg.size(), true);
}); });
// 4. In blocking mode, getting out of this function is triggered by
// an explicit disconnection from the callback, or by the remote end
// closing the connection, ie isConnected() == false.
if (!_thread.joinable() && !isConnected() && !_automaticReconnection) return;
} }
} }
@ -314,4 +344,14 @@ namespace ix
case WebSocket_ReadyState_Closed: return "CLOSED"; case WebSocket_ReadyState_Closed: return "CLOSED";
} }
} }
void WebSocket::enableAutomaticReconnection()
{
_automaticReconnection = true;
}
void WebSocket::disableAutomaticReconnection()
{
_automaticReconnection = false;
}
} }

View File

@ -41,6 +41,20 @@ namespace ix
WebSocket_MessageType_Pong = 5 WebSocket_MessageType_Pong = 5
}; };
struct WebSocketOpenInfo
{
std::string uri;
WebSocketHttpHeaders headers;
WebSocketOpenInfo(const std::string& u = std::string(),
const WebSocketHttpHeaders& h = WebSocketHttpHeaders())
: uri(u)
, headers(h)
{
;
}
};
struct WebSocketCloseInfo struct WebSocketCloseInfo
{ {
uint16_t code; uint16_t code;
@ -59,8 +73,9 @@ namespace ix
const std::string&, const std::string&,
size_t wireSize, size_t wireSize,
const WebSocketErrorInfo&, const WebSocketErrorInfo&,
const WebSocketCloseInfo&, const WebSocketOpenInfo&,
const WebSocketHttpHeaders&)>; const WebSocketCloseInfo&)>;
using OnTrafficTrackerCallback = std::function<void(size_t size, bool incoming)>; using OnTrafficTrackerCallback = std::function<void(size_t size, bool incoming)>;
class WebSocket class WebSocket
@ -71,9 +86,16 @@ namespace ix
void setUrl(const std::string& url); void setUrl(const std::string& url);
void setPerMessageDeflateOptions(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions); void setPerMessageDeflateOptions(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions);
void setHandshakeTimeout(int _handshakeTimeoutSecs);
// Run asynchronously, by calling start and stop.
void start(); void start();
void stop(); void stop();
// Run in blocking mode, by connecting first manually, and then calling run.
WebSocketInitResult connect(int timeoutSecs);
void run();
WebSocketSendInfo send(const std::string& text); WebSocketSendInfo send(const std::string& text);
WebSocketSendInfo ping(const std::string& text); WebSocketSendInfo ping(const std::string& text);
void close(); void close();
@ -86,18 +108,23 @@ namespace ix
const std::string& getUrl() const; const std::string& getUrl() const;
const WebSocketPerMessageDeflateOptions& getPerMessageDeflateOptions() const; const WebSocketPerMessageDeflateOptions& getPerMessageDeflateOptions() const;
void enableAutomaticReconnection();
void disableAutomaticReconnection();
private: private:
void run();
WebSocketSendInfo sendMessage(const std::string& text, bool ping); WebSocketSendInfo sendMessage(const std::string& text, bool ping);
WebSocketInitResult connect();
bool isConnected() const; bool isConnected() const;
bool isClosing() const; bool isClosing() const;
void reconnectPerpetuallyIfDisconnected(); void reconnectPerpetuallyIfDisconnected();
std::string readyStateToString(ReadyState readyState); std::string readyStateToString(ReadyState readyState);
static void invokeTrafficTrackerCallback(size_t size, bool incoming); static void invokeTrafficTrackerCallback(size_t size, bool incoming);
// Server
void setSocketFileDescriptor(int fd);
WebSocketInitResult connectToSocket(int fd, int timeoutSecs);
WebSocketTransport _ws; WebSocketTransport _ws;
std::string _url; std::string _url;
@ -111,5 +138,10 @@ namespace ix
std::atomic<bool> _automaticReconnection; std::atomic<bool> _automaticReconnection;
std::thread _thread; std::thread _thread;
std::mutex _writeMutex; std::mutex _writeMutex;
std::atomic<int> _handshakeTimeoutSecs;
static const int kDefaultHandShakeTimeoutSecs;
friend class WebSocketServer;
}; };
} }

View File

@ -0,0 +1,496 @@
/*
* IXWebSocketHandshake.h
* Author: Benjamin Sergeant
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
*/
#include "IXWebSocketHandshake.h"
#include "IXSocketConnect.h"
#include "libwshandshake.hpp"
#include <iostream>
#include <sstream>
#include <regex>
#include <random>
#include <algorithm>
namespace ix
{
WebSocketHandshake::WebSocketHandshake(std::atomic<bool>& requestInitCancellation,
std::shared_ptr<Socket> socket,
WebSocketPerMessageDeflate& perMessageDeflate,
WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
std::atomic<bool>& enablePerMessageDeflate) :
_requestInitCancellation(requestInitCancellation),
_socket(socket),
_perMessageDeflate(perMessageDeflate),
_perMessageDeflateOptions(perMessageDeflateOptions),
_enablePerMessageDeflate(enablePerMessageDeflate)
{
}
bool WebSocketHandshake::parseUrl(const std::string& url,
std::string& protocol,
std::string& host,
std::string& path,
std::string& query,
int& port)
{
std::regex ex("(ws|wss)://([^/ :]+):?([^/ ]*)(/?[^ #?]*)\\x3f?([^ #]*)#?([^ ]*)");
std::cmatch what;
if (!regex_match(url.c_str(), what, ex))
{
return false;
}
std::string portStr;
protocol = std::string(what[1].first, what[1].second);
host = std::string(what[2].first, what[2].second);
portStr = std::string(what[3].first, what[3].second);
path = std::string(what[4].first, what[4].second);
query = std::string(what[5].first, what[5].second);
if (portStr.empty())
{
if (protocol == "ws")
{
port = 80;
}
else if (protocol == "wss")
{
port = 443;
}
else
{
// Invalid protocol. Should be caught by regex check
// but this missing branch trigger cpplint linter.
return false;
}
}
else
{
std::stringstream ss;
ss << portStr;
ss >> port;
}
if (path.empty())
{
path = "/";
}
else if (path[0] != '/')
{
path = '/' + path;
}
if (!query.empty())
{
path += "?";
path += query;
}
return true;
}
void WebSocketHandshake::printUrl(const std::string& url)
{
std::string protocol, host, path, query;
int port {0};
if (!WebSocketHandshake::parseUrl(url, protocol, host,
path, query, port))
{
return;
}
std::cout << "[" << url << "]" << std::endl;
std::cout << protocol << std::endl;
std::cout << host << std::endl;
std::cout << port << std::endl;
std::cout << path << std::endl;
std::cout << query << std::endl;
std::cout << "-------------------------------" << std::endl;
}
std::string trim(const std::string& str)
{
std::string out(str);
out.erase(std::remove(out.begin(), out.end(), ' '), out.end());
out.erase(std::remove(out.begin(), out.end(), '\r'), out.end());
out.erase(std::remove(out.begin(), out.end(), '\n'), out.end());
return out;
}
std::tuple<std::string, std::string, std::string> WebSocketHandshake::parseRequestLine(const std::string& line)
{
// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
std::string token;
std::stringstream tokenStream(line);
std::vector<std::string> tokens;
// Split by ' '
while (std::getline(tokenStream, token, ' '))
{
tokens.push_back(token);
}
std::string method;
if (tokens.size() >= 1)
{
method = trim(tokens[0]);
}
std::string requestUri;
if (tokens.size() >= 2)
{
requestUri = trim(tokens[1]);
}
std::string httpVersion;
if (tokens.size() >= 3)
{
httpVersion = trim(tokens[2]);
}
return std::make_tuple(method, requestUri, httpVersion);
}
std::string WebSocketHandshake::genRandomString(const int len)
{
std::string alphanum =
"0123456789"
"ABCDEFGH"
"abcdefgh";
std::random_device r;
std::default_random_engine e1(r());
std::uniform_int_distribution<int> dist(0, (int) alphanum.size() - 1);
std::string s;
s.resize(len);
for (int i = 0; i < len; ++i)
{
int x = dist(e1);
s[i] = alphanum[x];
}
return s;
}
std::pair<bool, WebSocketHttpHeaders> WebSocketHandshake::parseHttpHeaders(
const CancellationRequest& isCancellationRequested)
{
WebSocketHttpHeaders headers;
char line[256];
int i;
while (true)
{
int colon = 0;
for (i = 0;
i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n');
++i)
{
if (!_socket->readByte(line+i, isCancellationRequested))
{
return std::make_pair(false, headers);
}
if (line[i] == ':' && colon == 0)
{
colon = i;
}
}
if (line[0] == '\r' && line[1] == '\n')
{
break;
}
// line is a single header entry. split by ':', and add it to our
// header map. ignore lines with no colon.
if (colon > 0)
{
line[i] = '\0';
std::string lineStr(line);
// colon is ':', colon+1 is ' ', colon+2 is the start of the value.
// i is end of string (\0), i-colon is length of string minus key;
// subtract 1 for '\0', 1 for '\n', 1 for '\r',
// 1 for the ' ' after the ':', and total is -4
std::string name(lineStr.substr(0, colon));
std::string value(lineStr.substr(colon + 2, i - colon - 4));
// Make the name lower case.
std::transform(name.begin(), name.end(), name.begin(), ::tolower);
headers[name] = value;
}
}
return std::make_pair(true, headers);
}
WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason)
{
std::stringstream ss;
ss << "HTTP/1.1 ";
ss << code;
ss << "\r\n";
ss << reason;
ss << "\r\n";
// Socket write can only be cancelled through a timeout here, not manually.
static std::atomic<bool> requestInitCancellation(false);
auto isCancellationRequested =
makeCancellationRequestWithTimeout(1, requestInitCancellation);
if (!_socket->writeBytes(ss.str(), isCancellationRequested))
{
return WebSocketInitResult(false, 500, "Timed out while sending error response");
}
return WebSocketInitResult(false, code, reason);
}
WebSocketInitResult WebSocketHandshake::clientHandshake(const std::string& url,
const std::string& host,
const std::string& path,
int port,
int timeoutSecs)
{
_requestInitCancellation = false;
auto isCancellationRequested =
makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
std::string errMsg;
bool success = _socket->connect(host, port, errMsg, isCancellationRequested);
if (!success)
{
std::stringstream ss;
ss << "Unable to connect to " << host
<< " on port " << port
<< ", error: " << errMsg;
return WebSocketInitResult(false, 0, ss.str());
}
//
// Generate a random 24 bytes string which looks like it is base64 encoded
// y3JJHMbDL1EzLkh9GBhXDw==
// 0cb3Vd9HkbpVVumoS3Noka==
//
// See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
//
std::string secWebSocketKey = genRandomString(22);
secWebSocketKey += "==";
std::stringstream ss;
ss << "GET " << path << " HTTP/1.1\r\n";
ss << "Host: "<< host << ":" << port << "\r\n";
ss << "Upgrade: websocket\r\n";
ss << "Connection: Upgrade\r\n";
ss << "Sec-WebSocket-Version: 13\r\n";
ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";
if (_enablePerMessageDeflate)
{
ss << _perMessageDeflateOptions.generateHeader();
}
ss << "\r\n";
if (!_socket->writeBytes(ss.str(), isCancellationRequested))
{
return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + url);
}
// Read HTTP status line
auto lineResult = _socket->readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
if (!lineValid)
{
return WebSocketInitResult(false, 0,
std::string("Failed reading HTTP status line from ") + url);
}
// Validate status
int status;
// HTTP/1.0 is too old.
if (sscanf(line.c_str(), "HTTP/1.0 %d", &status) == 1)
{
std::stringstream ss;
ss << "Server version is HTTP/1.0. Rejecting connection to " << host
<< ", status: " << status
<< ", HTTP Status line: " << line;
return WebSocketInitResult(false, status, ss.str());
}
// We want an 101 HTTP status
if (sscanf(line.c_str(), "HTTP/1.1 %d", &status) != 1 || status != 101)
{
std::stringstream ss;
ss << "Got bad status connecting to " << host
<< ", status: " << status
<< ", HTTP Status line: " << line;
return WebSocketInitResult(false, status, ss.str());
}
auto result = parseHttpHeaders(isCancellationRequested);
auto headersValid = result.first;
auto headers = result.second;
if (!headersValid)
{
return WebSocketInitResult(false, status, "Error parsing HTTP headers");
}
char output[29] = {};
WebSocketHandshakeKeyGen::generate(secWebSocketKey.c_str(), output);
if (std::string(output) != headers["sec-websocket-accept"])
{
std::string errorMsg("Invalid Sec-WebSocket-Accept value");
return WebSocketInitResult(false, status, errorMsg);
}
if (_enablePerMessageDeflate)
{
// Parse the server response. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
// If the server does not support that extension, disable it.
if (!webSocketPerMessageDeflateOptions.enabled())
{
_enablePerMessageDeflate = false;
}
// Otherwise try to initialize the deflate engine (zlib)
else if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
{
return WebSocketInitResult(
false, 0,"Failed to initialize per message deflate engine");
}
}
return WebSocketInitResult(true, status, "", headers, path);
}
WebSocketInitResult WebSocketHandshake::serverHandshake(int fd, int timeoutSecs)
{
_requestInitCancellation = false;
// Set the socket to non blocking mode + other tweaks
SocketConnect::configure(fd);
auto isCancellationRequested =
makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
std::string remote = std::string("remote fd ") + std::to_string(fd);
// Read first line
auto lineResult = _socket->readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
if (!lineValid)
{
return sendErrorResponse(400, "Error reading HTTP request line");
}
// Validate request line (GET /foo HTTP/1.1\r\n)
auto requestLine = parseRequestLine(line);
auto method = std::get<0>(requestLine);
auto uri = std::get<1>(requestLine);
auto httpVersion = std::get<2>(requestLine);
if (method != "GET")
{
return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method);
}
if (httpVersion != "HTTP/1.1")
{
return sendErrorResponse(400, "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
}
// Retrieve and validate HTTP headers
auto result = parseHttpHeaders(isCancellationRequested);
auto headersValid = result.first;
auto headers = result.second;
if (!headersValid)
{
return sendErrorResponse(400, "Error parsing HTTP headers");
}
if (headers.find("sec-websocket-key") == headers.end())
{
return sendErrorResponse(400, "Missing Sec-WebSocket-Key value");
}
if (headers["upgrade"] != "websocket")
{
return sendErrorResponse(400, "Invalid or missing Upgrade header");
}
if (headers.find("sec-websocket-version") == headers.end())
{
return sendErrorResponse(400, "Missing Sec-WebSocket-Version value");
}
{
std::stringstream ss;
ss << headers["sec-websocket-version"];
int version;
ss >> version;
if (version != 13)
{
return sendErrorResponse(400, "Invalid Sec-WebSocket-Version, "
"need 13, got" + ss.str());
}
}
char output[29] = {};
WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"].c_str(), output);
std::stringstream ss;
ss << "HTTP/1.1 101\r\n";
ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n";
// Parse the client headers. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
// If the client has requested that extension, enable it.
if (webSocketPerMessageDeflateOptions.enabled())
{
_enablePerMessageDeflate = true;
if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
{
return WebSocketInitResult(
false, 0,"Failed to initialize per message deflate engine");
}
ss << webSocketPerMessageDeflateOptions.generateHeader();
}
ss << "\r\n";
if (!_socket->writeBytes(ss.str(), isCancellationRequested))
{
return WebSocketInitResult(false, 0, std::string("Failed sending response to ") + remote);
}
return WebSocketInitResult(true, 200, "", headers, uri);
}
}

View File

@ -0,0 +1,85 @@
/*
* IXWebSocketHandshake.h
* Author: Benjamin Sergeant
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
*/
#pragma once
#include "IXCancellationRequest.h"
#include "IXWebSocketHttpHeaders.h"
#include "IXWebSocketPerMessageDeflate.h"
#include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXSocket.h"
#include <string>
#include <atomic>
#include <chrono>
#include <memory>
#include <tuple>
namespace ix
{
struct WebSocketInitResult
{
bool success;
int http_status;
std::string errorStr;
WebSocketHttpHeaders headers;
std::string uri;
WebSocketInitResult(bool s = false,
int status = 0,
const std::string& e = std::string(),
WebSocketHttpHeaders h = WebSocketHttpHeaders(),
const std::string& u = std::string())
{
success = s;
http_status = status;
errorStr = e;
headers = h;
uri = u;
}
};
class WebSocketHandshake {
public:
WebSocketHandshake(std::atomic<bool>& requestInitCancellation,
std::shared_ptr<Socket> _socket,
WebSocketPerMessageDeflate& perMessageDeflate,
WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
std::atomic<bool>& enablePerMessageDeflate);
WebSocketInitResult clientHandshake(const std::string& url,
const std::string& host,
const std::string& path,
int port,
int timeoutSecs);
WebSocketInitResult serverHandshake(int fd,
int timeoutSecs);
static bool parseUrl(const std::string& url,
std::string& protocol,
std::string& host,
std::string& path,
std::string& query,
int& port);
private:
static void printUrl(const std::string& url);
std::string genRandomString(const int len);
// Parse HTTP headers
std::pair<bool, WebSocketHttpHeaders> parseHttpHeaders(const CancellationRequest& isCancellationRequested);
WebSocketInitResult sendErrorResponse(int code, const std::string& reason);
std::tuple<std::string, std::string, std::string> parseRequestLine(const std::string& line);
std::atomic<bool>& _requestInitCancellation;
std::shared_ptr<Socket> _socket;
WebSocketPerMessageDeflate& _perMessageDeflate;
WebSocketPerMessageDeflateOptions& _perMessageDeflateOptions;
std::atomic<bool>& _enablePerMessageDeflate;
};
}

View File

@ -0,0 +1,281 @@
/*
* IXWebSocketServer.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#include "IXWebSocketServer.h"
#include "IXWebSocketTransport.h"
#include "IXWebSocket.h"
#include "IXSocketConnect.h"
#include <sstream>
#include <future>
#include <netdb.h>
#include <stdio.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <string.h>
namespace ix
{
const int WebSocketServer::kDefaultPort(8080);
const std::string WebSocketServer::kDefaultHost("127.0.0.1");
const int WebSocketServer::kDefaultTcpBacklog(5);
const size_t WebSocketServer::kDefaultMaxConnections(32);
const int WebSocketServer::kDefaultHandShakeTimeoutSecs(3); // 3 seconds
WebSocketServer::WebSocketServer(int port,
const std::string& host,
int backlog,
size_t maxConnections,
int handshakeTimeoutSecs) :
_port(port),
_host(host),
_backlog(backlog),
_maxConnections(maxConnections),
_handshakeTimeoutSecs(handshakeTimeoutSecs),
_stop(false)
{
}
WebSocketServer::~WebSocketServer()
{
stop();
}
void WebSocketServer::setOnConnectionCallback(const OnConnectionCallback& callback)
{
_onConnectionCallback = callback;
}
void WebSocketServer::logError(const std::string& str)
{
std::lock_guard<std::mutex> lock(_logMutex);
std::cerr << str << std::endl;
}
void WebSocketServer::logInfo(const std::string& str)
{
std::lock_guard<std::mutex> lock(_logMutex);
std::cout << str << std::endl;
}
std::pair<bool, std::string> WebSocketServer::listen()
{
struct sockaddr_in server; // server address information
// Get a socket for accepting connections.
if ((_serverFd = socket(AF_INET, SOCK_STREAM, 0)) < 0)
{
std::stringstream ss;
ss << "WebSocketServer::listen() error creating socket): "
<< strerror(errno);
return std::make_pair(false, ss.str());
}
// Make that socket reusable. (allow restarting this server at will)
int enable = 1;
if (setsockopt(_serverFd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0)
{
std::stringstream ss;
ss << "WebSocketServer::listen() error calling setsockopt(SO_REUSEADDR): "
<< strerror(errno);
return std::make_pair(false, ss.str());
}
// Bind the socket to the server address.
server.sin_family = AF_INET;
server.sin_port = htons(_port);
// Using INADDR_ANY trigger a pop-up box as binding to any address is detected
// by the osx firewall. We need to codesign the binary with a self-signed cert
// to allow that, but this is a bit of a pain. (this is what node or python would do).
//
// Using INADDR_LOOPBACK also does not work ... while it should.
// We default to 127.0.0.1 (localhost)
//
server.sin_addr.s_addr = inet_addr(_host.c_str());
if (bind(_serverFd, (struct sockaddr *)&server, sizeof(server)) < 0)
{
std::stringstream ss;
ss << "WebSocketServer::listen() error calling bind: "
<< strerror(errno);
return std::make_pair(false, ss.str());
}
/*
* Listen for connections. Specify the tcp backlog.
*/
if (::listen(_serverFd, _backlog) != 0)
{
std::stringstream ss;
ss << "WebSocketServer::listen() error calling listen: "
<< strerror(errno);
return std::make_pair(false, ss.str());
}
return std::make_pair(true, "");
}
void WebSocketServer::start()
{
if (_thread.joinable()) return; // we've already been started
_thread = std::thread(&WebSocketServer::run, this);
}
void WebSocketServer::wait()
{
std::unique_lock<std::mutex> lock(_conditionVariableMutex);
_conditionVariable.wait(lock);
}
void WebSocketServer::stop()
{
if (!_thread.joinable()) return; // nothing to do
auto clients = getClients();
for (auto client : clients)
{
client->close();
}
_stop = true;
_thread.join();
_stop = false;
_conditionVariable.notify_one();
}
void WebSocketServer::run()
{
// Set the socket to non blocking mode, so that accept calls are not blocking
SocketConnect::configure(_serverFd);
// Return value of std::async, ignored
std::future<void> f;
// Select arguments
fd_set rfds;
struct timeval timeout;
timeout.tv_sec = 0;
timeout.tv_usec = 10 * 1000; // 10ms
for (;;)
{
if (_stop) return;
FD_ZERO(&rfds);
FD_SET(_serverFd, &rfds);
select(_serverFd + 1, &rfds, nullptr, nullptr, &timeout);
if (!FD_ISSET(_serverFd, &rfds))
{
// We reached the select timeout, and no new connections are pending
continue;
}
// Accept a connection.
struct sockaddr_in client; // client address information
int clientFd; // socket connected to client
socklen_t addressLen = sizeof(socklen_t);
memset(&client, 0, sizeof(client));
if ((clientFd = accept(_serverFd, (struct sockaddr *)&client, &addressLen)) < 0)
{
if (errno != EWOULDBLOCK)
{
// FIXME: that error should be propagated
std::stringstream ss;
ss << "WebSocketServer::run() error accepting connection: "
<< strerror(errno);
logError(ss.str());
}
continue;
}
if (getConnectedClientsCount() >= _maxConnections)
{
std::stringstream ss;
ss << "WebSocketServer::run() reached max connections = "
<< _maxConnections << ". "
<< "Not accepting connection";
logError(ss.str());
::close(clientFd);
continue;
}
// Launch the handleConnection work asynchronously in its own thread.
//
// the destructor of a future returned by std::async blocks,
// so we need to declare it outside of this loop
f = std::async(std::launch::async,
&WebSocketServer::handleConnection,
this,
clientFd);
}
}
void WebSocketServer::handleConnection(int fd)
{
std::shared_ptr<WebSocket> webSocket(new WebSocket);
_onConnectionCallback(webSocket);
webSocket->disableAutomaticReconnection();
// Add this client to our client set
{
std::lock_guard<std::mutex> lock(_clientsMutex);
_clients.insert(webSocket);
}
auto status = webSocket->connectToSocket(fd, _handshakeTimeoutSecs);
if (status.success)
{
// Process incoming messages and execute callbacks
// until the connection is closed
webSocket->run();
}
else
{
std::stringstream ss;
ss << "WebSocketServer::handleConnection() error: "
<< status.http_status
<< " error: "
<< status.errorStr;
logError(ss.str());
}
// Remove this client from our client set
{
std::lock_guard<std::mutex> lock(_clientsMutex);
if (_clients.erase(webSocket) != 1)
{
logError("Cannot delete client");
}
}
logInfo("WebSocketServer::handleConnection() done");
}
std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients()
{
std::lock_guard<std::mutex> lock(_clientsMutex);
return _clients;
}
size_t WebSocketServer::getConnectedClientsCount()
{
return getClients().size();
}
}

View File

@ -0,0 +1,82 @@
/*
* IXWebSocketServer.h
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#pragma once
#include <utility> // pair
#include <string>
#include <set>
#include <thread>
#include <mutex>
#include <functional>
#include <memory>
#include <condition_variable>
#include "IXWebSocket.h"
namespace ix
{
using OnConnectionCallback = std::function<void(std::shared_ptr<WebSocket>)>;
class WebSocketServer {
public:
WebSocketServer(int port = WebSocketServer::kDefaultPort,
const std::string& host = WebSocketServer::kDefaultHost,
int backlog = WebSocketServer::kDefaultTcpBacklog,
size_t maxConnections = WebSocketServer::kDefaultMaxConnections,
int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs);
virtual ~WebSocketServer();
void setOnConnectionCallback(const OnConnectionCallback& callback);
void start();
void wait();
void stop();
std::pair<bool, std::string> listen();
// Get all the connected clients
std::set<std::shared_ptr<WebSocket>> getClients();
private:
// Member variables
int _port;
std::string _host;
int _backlog;
size_t _maxConnections;
int _handshakeTimeoutSecs;
OnConnectionCallback _onConnectionCallback;
// socket for accepting connections
int _serverFd;
std::mutex _clientsMutex;
std::set<std::shared_ptr<WebSocket>> _clients;
std::mutex _logMutex;
std::atomic<bool> _stop;
std::thread _thread;
std::condition_variable _conditionVariable;
std::mutex _conditionVariableMutex;
const static int kDefaultPort;
const static std::string kDefaultHost;
const static int kDefaultTcpBacklog;
const static size_t kDefaultMaxConnections;
const static int kDefaultHandShakeTimeoutSecs;
// Methods
void run();
void handleConnection(int fd);
size_t getConnectedClientsCount();
// Logging
void logError(const std::string& str);
void logInfo(const std::string& str);
};
}

View File

@ -9,9 +9,9 @@
// //
#include "IXWebSocketTransport.h" #include "IXWebSocketTransport.h"
#include "IXWebSocketHandshake.h"
#include "IXWebSocketHttpHeaders.h" #include "IXWebSocketHttpHeaders.h"
#include "IXSocket.h"
#ifdef IXWEBSOCKET_USE_TLS #ifdef IXWEBSOCKET_USE_TLS
# ifdef __APPLE__ # ifdef __APPLE__
# include "IXSocketAppleSSL.h" # include "IXSocketAppleSSL.h"
@ -20,8 +20,6 @@
# endif # endif
#endif #endif
#include "libwshandshake.hpp"
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
@ -31,9 +29,6 @@
#include <cstdarg> #include <cstdarg>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <regex>
#include <random>
#include <algorithm>
namespace ix namespace ix
@ -53,133 +48,24 @@ namespace ix
; ;
} }
void WebSocketTransport::configure(const std::string& url, void WebSocketTransport::configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions)
const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions)
{ {
_url = url;
_perMessageDeflateOptions = perMessageDeflateOptions; _perMessageDeflateOptions = perMessageDeflateOptions;
_enablePerMessageDeflate = _perMessageDeflateOptions.enabled(); _enablePerMessageDeflate = _perMessageDeflateOptions.enabled();
} }
bool WebSocketTransport::parseUrl(const std::string& url, // Client
std::string& protocol, WebSocketInitResult WebSocketTransport::connectToUrl(const std::string& url,
std::string& host, int timeoutSecs)
std::string& path,
std::string& query,
int& port)
{
std::regex ex("(ws|wss)://([^/ :]+):?([^/ ]*)(/?[^ #?]*)\\x3f?([^ #]*)#?([^ ]*)");
std::cmatch what;
if (!regex_match(url.c_str(), what, ex))
{
return false;
}
std::string portStr;
protocol = std::string(what[1].first, what[1].second);
host = std::string(what[2].first, what[2].second);
portStr = std::string(what[3].first, what[3].second);
path = std::string(what[4].first, what[4].second);
query = std::string(what[5].first, what[5].second);
if (portStr.empty())
{
if (protocol == "ws")
{
port = 80;
}
else if (protocol == "wss")
{
port = 443;
}
else
{
// Invalid protocol. Should be caught by regex check
// but this missing branch trigger cpplint linter.
return false;
}
}
else
{
std::stringstream ss;
ss << portStr;
ss >> port;
}
if (path.empty())
{
path = "/";
}
else if (path[0] != '/')
{
path = '/' + path;
}
if (!query.empty())
{
path += "?";
path += query;
}
return true;
}
void WebSocketTransport::printUrl(const std::string& url)
{
std::string protocol, host, path, query;
int port {0};
if (!WebSocketTransport::parseUrl(url, protocol, host,
path, query, port))
{
return;
}
std::cout << "[" << url << "]" << std::endl;
std::cout << protocol << std::endl;
std::cout << host << std::endl;
std::cout << port << std::endl;
std::cout << path << std::endl;
std::cout << query << std::endl;
std::cout << "-------------------------------" << std::endl;
}
std::string WebSocketTransport::genRandomString(const int len)
{
std::string alphanum =
"0123456789"
"ABCDEFGH"
"abcdefgh";
std::random_device r;
std::default_random_engine e1(r());
std::uniform_int_distribution<int> dist(0, (int) alphanum.size() - 1);
std::string s;
s.resize(len);
for (int i = 0; i < len; ++i)
{
int x = dist(e1);
s[i] = alphanum[x];
}
return s;
}
WebSocketInitResult WebSocketTransport::init()
{ {
std::string protocol, host, path, query; std::string protocol, host, path, query;
int port; int port;
_requestInitCancellation = false; if (!WebSocketHandshake::parseUrl(url, protocol, host,
if (!WebSocketTransport::parseUrl(_url, protocol, host,
path, query, port)) path, query, port))
{ {
return WebSocketInitResult(false, 0, return WebSocketInitResult(false, 0,
std::string("Could not parse URL ") + _url); std::string("Could not parse URL ") + url);
} }
if (protocol == "wss") if (protocol == "wss")
@ -201,165 +87,39 @@ namespace ix
_socket = std::make_shared<Socket>(); _socket = std::make_shared<Socket>();
} }
std::string errMsg; WebSocketHandshake webSocketHandshake(_requestInitCancellation,
bool success = _socket->connect(host, port, errMsg, _socket,
[this]() -> bool _perMessageDeflate,
{ _perMessageDeflateOptions,
return _requestInitCancellation; _enablePerMessageDeflate);
}
); auto result = webSocketHandshake.clientHandshake(url, host, path, port,
if (!success) timeoutSecs);
if (result.success)
{ {
std::stringstream ss; setReadyState(OPEN);
ss << "Unable to connect to " << host
<< " on port " << port
<< ", error: " << errMsg;
return WebSocketInitResult(false, 0, ss.str());
} }
return result;
}
// // Server
// Generate a random 24 bytes string which looks like it is base64 encoded WebSocketInitResult WebSocketTransport::connectToSocket(int fd, int timeoutSecs)
// y3JJHMbDL1EzLkh9GBhXDw== {
// 0cb3Vd9HkbpVVumoS3Noka== _socket.reset();
// _socket = std::make_shared<Socket>(fd);
// See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
//
std::string secWebSocketKey = genRandomString(22);
secWebSocketKey += "==";
std::stringstream ss; WebSocketHandshake webSocketHandshake(_requestInitCancellation,
ss << "GET " << path << " HTTP/1.1\r\n"; _socket,
ss << "Host: "<< host << ":" << port << "\r\n"; _perMessageDeflate,
ss << "Upgrade: websocket\r\n"; _perMessageDeflateOptions,
ss << "Connection: Upgrade\r\n"; _enablePerMessageDeflate);
ss << "Sec-WebSocket-Version: 13\r\n";
ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";
if (_enablePerMessageDeflate) auto result = webSocketHandshake.serverHandshake(fd, timeoutSecs);
if (result.success)
{ {
ss << _perMessageDeflateOptions.generateHeader(); setReadyState(OPEN);
} }
return result;
ss << "\r\n";
if (!writeBytes(ss.str()))
{
return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + _url);
}
char line[256];
int i;
for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i)
{
if (!readByte(line+i))
{
return WebSocketInitResult(false, 0, std::string("Failed reading HTTP status line from ") + _url);
}
}
line[i] = 0;
if (i == 255)
{
return WebSocketInitResult(false, 0, std::string("Got bad status line connecting to ") + _url);
}
// Validate status
int status;
// HTTP/1.0 is too old.
if (sscanf(line, "HTTP/1.0 %d", &status) == 1)
{
std::stringstream ss;
ss << "Server version is HTTP/1.0. Rejecting connection to " << host
<< ", status: " << status
<< ", HTTP Status line: " << line;
return WebSocketInitResult(false, status, ss.str());
}
// We want an 101 HTTP status
if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101)
{
std::stringstream ss;
ss << "Got bad status connecting to " << host
<< ", status: " << status
<< ", HTTP Status line: " << line;
return WebSocketInitResult(false, status, ss.str());
}
WebSocketHttpHeaders headers;
while (true)
{
int colon = 0;
for (i = 0;
i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n');
++i)
{
if (!readByte(line+i))
{
return WebSocketInitResult(false, status, std::string("Failed reading response header from ") + _url);
}
if (line[i] == ':' && colon == 0)
{
colon = i;
}
}
if (line[0] == '\r' && line[1] == '\n')
{
break;
}
// line is a single header entry. split by ':', and add it to our
// header map. ignore lines with no colon.
if (colon > 0)
{
line[i] = '\0';
std::string lineStr(line);
// colon is ':', colon+1 is ' ', colon+2 is the start of the value.
// i is end of string (\0), i-colon is length of string minus key;
// subtract 1 for '\0', 1 for '\n', 1 for '\r',
// 1 for the ' ' after the ':', and total is -4
std::string name(lineStr.substr(0, colon));
std::string value(lineStr.substr(colon + 2, i - colon - 4));
// Make the name lower case.
std::transform(name.begin(), name.end(), name.begin(), ::tolower);
headers[name] = value;
}
}
char output[29] = {};
WebSocketHandshake::generate(secWebSocketKey.c_str(), output);
if (std::string(output) != headers["sec-websocket-accept"])
{
std::string errorMsg("Invalid Sec-WebSocket-Accept value");
return WebSocketInitResult(false, status, errorMsg);
}
if (_enablePerMessageDeflate)
{
// Parse the server response. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
// If the server does not support that extension, disable it.
if (!webSocketPerMessageDeflateOptions.enabled())
{
_enablePerMessageDeflate = false;
}
if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
{
return WebSocketInitResult(
false, 0,"Failed to initialize per message deflate engine");
}
}
setReadyState(OPEN);
return WebSocketInitResult(true, status, "", headers);
} }
WebSocketTransport::ReadyStateValues WebSocketTransport::getReadyState() const WebSocketTransport::ReadyStateValues WebSocketTransport::getReadyState() const
@ -822,62 +582,4 @@ namespace ix
_socket->close(); _socket->close();
} }
bool WebSocketTransport::readByte(void* buffer)
{
while (true)
{
if (_readyState == CLOSING) return false;
int ret;
ret = _socket->recv(buffer, 1);
// We read one byte, as needed, all good.
if (ret == 1)
{
return true;
}
// There is possibly something to be read, try again
else if (ret < 0 && (_socket->getErrno() == EWOULDBLOCK ||
_socket->getErrno() == EAGAIN))
{
continue;
}
// There was an error during the read, abort
else
{
return false;
}
}
}
bool WebSocketTransport::writeBytes(const std::string& str)
{
while (true)
{
if (_readyState == CLOSING) return false;
char* buffer = const_cast<char*>(str.c_str());
int len = (int) str.size();
int ret = _socket->send(buffer, len);
// We wrote some bytes, as needed, all good.
if (ret > 0)
{
return ret == len;
}
// There is possibly something to be write, try again
else if (ret < 0 && (_socket->getErrno() == EWOULDBLOCK ||
_socket->getErrno() == EAGAIN))
{
continue;
}
// There was an error during the write, abort
else
{
return false;
}
}
}
} // namespace ix } // namespace ix

View File

@ -21,30 +21,13 @@
#include "IXWebSocketPerMessageDeflate.h" #include "IXWebSocketPerMessageDeflate.h"
#include "IXWebSocketPerMessageDeflateOptions.h" #include "IXWebSocketPerMessageDeflateOptions.h"
#include "IXWebSocketHttpHeaders.h" #include "IXWebSocketHttpHeaders.h"
#include "IXCancellationRequest.h"
#include "IXWebSocketHandshake.h"
namespace ix namespace ix
{ {
class Socket; class Socket;
struct WebSocketInitResult
{
bool success;
int http_status;
std::string errorStr;
WebSocketHttpHeaders headers;
WebSocketInitResult(bool s = false,
int status = 0,
const std::string& e = std::string(),
WebSocketHttpHeaders h = WebSocketHttpHeaders())
{
success = s;
http_status = status;
errorStr = e;
headers = h;
}
};
class WebSocketTransport class WebSocketTransport
{ {
public: public:
@ -74,9 +57,12 @@ namespace ix
WebSocketTransport(); WebSocketTransport();
~WebSocketTransport(); ~WebSocketTransport();
void configure(const std::string& url, void configure(const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions);
const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions);
WebSocketInitResult init(); WebSocketInitResult connectToUrl(const std::string& url, // Client
int timeoutSecs);
WebSocketInitResult connectToSocket(int fd, // Server
int timeoutSecs);
void poll(); void poll();
WebSocketSendInfo sendBinary(const std::string& message); WebSocketSendInfo sendBinary(const std::string& message);
@ -87,14 +73,6 @@ namespace ix
void setOnCloseCallback(const OnCloseCallback& onCloseCallback); void setOnCloseCallback(const OnCloseCallback& onCloseCallback);
void dispatch(const OnMessageCallback& onMessageCallback); void dispatch(const OnMessageCallback& onMessageCallback);
static void printUrl(const std::string& url);
static bool parseUrl(const std::string& url,
std::string& protocol,
std::string& host,
std::string& path,
std::string& query,
int& port);
private: private:
std::string _url; std::string _url;
std::string _origin; std::string _origin;
@ -159,10 +137,5 @@ namespace ix
unsigned getRandomUnsigned(); unsigned getRandomUnsigned();
void unmaskReceiveBuffer(const wsheader_type& ws); void unmaskReceiveBuffer(const wsheader_type& ws);
std::string genRandomString(const int len);
// Non blocking versions of read/write, used during http upgrade
bool readByte(void* buffer);
bool writeBytes(const std::string& str);
}; };
} }

View File

@ -0,0 +1,20 @@
/*
* IXSetThreadName_apple.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#include "../IXSetThreadName.h"
#include <pthread.h>
namespace ix
{
void setThreadName(const std::string& name)
{
//
// Apple reserves 16 bytes for its thread names
// Notice that the Apple version of pthread_setname_np
// does not take a pthread_t argument
//
pthread_setname_np(name.substr(0, 63).c_str());
}
}

View File

@ -21,7 +21,7 @@
#include <cstdint> #include <cstdint>
#include <cstddef> #include <cstddef>
class WebSocketHandshake { class WebSocketHandshakeKeyGen {
template <int N, typename T> template <int N, typename T>
struct static_for { struct static_for {
void operator()(uint32_t *a, uint32_t *b) { void operator()(uint32_t *a, uint32_t *b) {

View File

@ -0,0 +1,21 @@
/*
* IXSetThreadName_linux.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
*/
#include "../IXSetThreadName.h"
#include <pthread.h>
namespace ix
{
void setThreadName(const std::string& name)
{
//
// Linux only reserves 16 bytes for its thread names
// See prctl and PR_SET_NAME property in
// http://man7.org/linux/man-pages/man2/prctl.2.html
//
pthread_setname_np(pthread_self(),
name.substr(0, 15).c_str());
}
}

View File

@ -15,10 +15,17 @@ build:
(cd examples/chat ; mkdir -p build ; cd build ; cmake .. ; make) (cd examples/chat ; mkdir -p build ; cd build ; cmake .. ; make)
(cd examples/ping_pong ; mkdir -p build ; cd build ; cmake .. ; make) (cd examples/ping_pong ; mkdir -p build ; cd build ; cmake .. ; make)
(cd examples/ws_connect ; mkdir -p build ; cd build ; cmake .. ; make) (cd examples/ws_connect ; mkdir -p build ; cd build ; cmake .. ; make)
(cd examples/echo_server ; mkdir -p build ; cd build ; cmake .. ; make)
(cd examples/broadcast_server ; mkdir -p build ; cd build ; cmake .. ; make)
# That target is used to start a node server, but isn't required as we have
# a builtin C++ server started in the unittest now
test_server: test_server:
(cd test && npm i ws && node broadcast-server.js) (cd test && npm i ws && node broadcast-server.js)
# env TEST=Websocket_server make test
test: test:
(cd test && cmake . && make && ./ixwebsocket_unittest) (cd test && sh run.sh)
.PHONY: test .PHONY: test
.PHONY: build

1
test/.gitignore vendored
View File

@ -6,3 +6,4 @@ cmake_install.cmake
node_modules node_modules
ixwebsocket ixwebsocket
Makefile Makefile
build

View File

@ -19,6 +19,7 @@ include_directories(
add_executable(ixwebsocket_unittest add_executable(ixwebsocket_unittest
test_runner.cpp test_runner.cpp
cmd_websocket_chat.cpp cmd_websocket_chat.cpp
IXWebSocketServerTest.cpp
IXTest.cpp IXTest.cpp
msgpack11.cpp msgpack11.cpp
) )

View File

@ -0,0 +1,180 @@
/*
* IXWebSocketServerTest.cpp
* Author: Benjamin Sergeant
* Copyright (c) 2019 Machine Zone. All rights reserved.
*/
#include <iostream>
#include <ixwebsocket/IXSocket.h>
#include <ixwebsocket/IXWebSocket.h>
#include <ixwebsocket/IXWebSocketServer.h>
#include "IXTest.h"
#include "catch.hpp"
using namespace ix;
namespace ix
{
bool startServer(ix::WebSocketServer& server)
{
server.setOnConnectionCallback(
[&server](std::shared_ptr<ix::WebSocket> webSocket)
{
webSocket->setOnMessageCallback(
[webSocket, &server](ix::WebSocketMessageType messageType,
const std::string& str,
size_t wireSize,
const ix::WebSocketErrorInfo& error,
const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketCloseInfo& closeInfo)
{
if (messageType == ix::WebSocket_MessageType_Open)
{
std::cerr << "New connection" << std::endl;
std::cerr << "Uri: " << openInfo.uri << std::endl;
std::cerr << "Headers:" << std::endl;
for (auto it : openInfo.headers)
{
std::cerr << it.first << ": " << it.second << std::endl;
}
}
else if (messageType == ix::WebSocket_MessageType_Close)
{
std::cerr << "Closed connection" << std::endl;
}
else if (messageType == ix::WebSocket_MessageType_Message)
{
for (auto&& client : server.getClients())
{
if (client != webSocket)
{
client->send(str);
}
}
}
}
);
}
);
auto res = server.listen();
if (!res.first)
{
std::cerr << res.second << std::endl;
return false;
}
server.start();
return true;
}
}
TEST_CASE("Websocket_server", "[websocket_server]")
{
SECTION("Connect to the server, do not send anything. Should timeout and return 400")
{
int port = 8091;
ix::WebSocketServer server(port);
REQUIRE(startServer(server));
Socket socket;
std::string host("localhost");
std::string errMsg;
auto isCancellationRequested = []() -> bool
{
return false;
};
bool success = socket.connect(host, port, errMsg, isCancellationRequested);
REQUIRE(success);
auto lineResult = socket.readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
int status = -1;
REQUIRE(sscanf(line.c_str(), "HTTP/1.1 %d", &status) == 1);
REQUIRE(status == 400);
// FIXME: explicitely set a client timeout larger than the server one (3)
// Give us 500ms for the server to notice that clients went away
ix::msleep(500);
server.stop();
REQUIRE(server.getClients().size() == 0);
}
SECTION("Connect to the server. Send GET request without header. Should return 400")
{
int port = 8092;
ix::WebSocketServer server(port);
REQUIRE(startServer(server));
Socket socket;
std::string host("localhost");
std::string errMsg;
auto isCancellationRequested = []() -> bool
{
return false;
};
bool success = socket.connect(host, port, errMsg, isCancellationRequested);
REQUIRE(success);
std::cout << "writeBytes" << std::endl;
socket.writeBytes("GET /\r\n", isCancellationRequested);
auto lineResult = socket.readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
int status = -1;
REQUIRE(sscanf(line.c_str(), "HTTP/1.1 %d", &status) == 1);
REQUIRE(status == 400);
// FIXME: explicitely set a client timeout larger than the server one (3)
// Give us 500ms for the server to notice that clients went away
ix::msleep(500);
server.stop();
REQUIRE(server.getClients().size() == 0);
}
SECTION("Connect to the server. Send GET request with correct header")
{
int port = 8093;
ix::WebSocketServer server(port);
REQUIRE(startServer(server));
Socket socket;
std::string host("localhost");
std::string errMsg;
auto isCancellationRequested = []() -> bool
{
return false;
};
bool success = socket.connect(host, port, errMsg, isCancellationRequested);
REQUIRE(success);
socket.writeBytes("GET / HTTP/1.1\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Key: foobar\r\n"
"\r\n",
isCancellationRequested);
auto lineResult = socket.readLine(isCancellationRequested);
auto lineValid = lineResult.first;
auto line = lineResult.second;
int status = -1;
REQUIRE(sscanf(line.c_str(), "HTTP/1.1 %d", &status) == 1);
REQUIRE(status == 101);
// Give us 500ms for the server to notice that clients went away
ix::msleep(500);
server.stop();
REQUIRE(server.getClients().size() == 0);
}
}

30
test/build_linux.sh Normal file
View File

@ -0,0 +1,30 @@
#!/bin/sh
#
# Author: Benjamin Sergeant
# Copyright (c) 2017-2018 Machine Zone, Inc. All rights reserved.
#
# 'manual' way of building. You can also use cmake.
g++ --std=c++11 \
-DIXWEBSOCKET_USE_TLS \
-g \
../ixwebsocket/IXEventFd.cpp \
../ixwebsocket/IXSocket.cpp \
../ixwebsocket/IXSetThreadName.cpp \
../ixwebsocket/IXWebSocketTransport.cpp \
../ixwebsocket/IXWebSocket.cpp \
../ixwebsocket/IXWebSocketServer.cpp \
../ixwebsocket/IXDNSLookup.cpp \
../ixwebsocket/IXSocketConnect.cpp \
../ixwebsocket/IXSocketOpenSSL.cpp \
../ixwebsocket/IXWebSocketPerMessageDeflate.cpp \
../ixwebsocket/IXWebSocketPerMessageDeflateOptions.cpp \
-I ../.. \
-I Catch2/single_include \
test_runner.cpp \
cmd_websocket_chat.cpp \
IXTest.cpp \
msgpack11.cpp \
-o ixwebsocket_unittest \
-lcrypto -lssl -lz -lpthread

View File

@ -13,6 +13,7 @@
#include <sstream> #include <sstream>
#include <queue> #include <queue>
#include <ixwebsocket/IXWebSocket.h> #include <ixwebsocket/IXWebSocket.h>
#include <ixwebsocket/IXWebSocketServer.h>
#include "msgpack11.hpp" #include "msgpack11.hpp"
#include "IXTest.h" #include "IXTest.h"
@ -82,7 +83,7 @@ namespace
void WebSocketChat::start() void WebSocketChat::start()
{ {
std::string url("ws://localhost:8080/"); std::string url("ws://localhost:8090/");
_webSocket.setUrl(url); _webSocket.setUrl(url);
std::stringstream ss; std::stringstream ss;
@ -93,8 +94,8 @@ namespace
const std::string& str, const std::string& str,
size_t wireSize, size_t wireSize,
const ix::WebSocketErrorInfo& error, const ix::WebSocketErrorInfo& error,
const ix::WebSocketCloseInfo& closeInfo, const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketHttpHeaders& headers) const ix::WebSocketCloseInfo& closeInfo)
{ {
std::stringstream ss; std::stringstream ss;
if (messageType == ix::WebSocket_MessageType_Open) if (messageType == ix::WebSocket_MessageType_Open)
@ -171,14 +172,71 @@ namespace
{ {
_webSocket.send(encodeMessage(text)); _webSocket.send(encodeMessage(text));
} }
bool startServer(ix::WebSocketServer& server)
{
server.setOnConnectionCallback(
[&server](std::shared_ptr<ix::WebSocket> webSocket)
{
webSocket->setOnMessageCallback(
[webSocket, &server](ix::WebSocketMessageType messageType,
const std::string& str,
size_t wireSize,
const ix::WebSocketErrorInfo& error,
const ix::WebSocketOpenInfo& openInfo,
const ix::WebSocketCloseInfo& closeInfo)
{
if (messageType == ix::WebSocket_MessageType_Open)
{
std::cerr << "New connection" << std::endl;
std::cerr << "Uri: " << openInfo.uri << std::endl;
std::cerr << "Headers:" << std::endl;
for (auto it : openInfo.headers)
{
std::cerr << it.first << ": " << it.second << std::endl;
}
}
else if (messageType == ix::WebSocket_MessageType_Close)
{
std::cerr << "Closed connection" << std::endl;
}
else if (messageType == ix::WebSocket_MessageType_Message)
{
for (auto&& client : server.getClients())
{
if (client != webSocket)
{
client->send(str);
}
}
}
}
);
}
);
auto res = server.listen();
if (!res.first)
{
std::cerr << res.second << std::endl;
return false;
}
server.start();
return true;
}
} }
TEST_CASE("Websocket chat", "[websocket_chat]") TEST_CASE("Websocket_chat", "[websocket_chat]")
{ {
SECTION("Exchange and count sent/received messages.") SECTION("Exchange and count sent/received messages.")
{ {
ix::setupWebSocketTrafficTrackerCallback(); ix::setupWebSocketTrafficTrackerCallback();
int port = 8090;
ix::WebSocketServer server(port);
REQUIRE(startServer(server));
std::string session = ix::generateSessionId(); std::string session = ix::generateSessionId();
WebSocketChat chatA("jean", session); WebSocketChat chatA("jean", session);
WebSocketChat chatB("paul", session); WebSocketChat chatB("paul", session);
@ -193,6 +251,8 @@ TEST_CASE("Websocket chat", "[websocket_chat]")
ix::msleep(10); ix::msleep(10);
} }
REQUIRE(server.getClients().size() == 2);
// Add a bit of extra time, for the subscription to be active // Add a bit of extra time, for the subscription to be active
ix::msleep(200); ix::msleep(200);
@ -212,6 +272,10 @@ TEST_CASE("Websocket chat", "[websocket_chat]")
REQUIRE(chatA.getReceivedMessagesCount() == 2); REQUIRE(chatA.getReceivedMessagesCount() == 2);
REQUIRE(chatB.getReceivedMessagesCount() == 3); REQUIRE(chatB.getReceivedMessagesCount() == 3);
// Give us 500ms for the server to notice that clients went away
ix::msleep(500);
REQUIRE(server.getClients().size() == 0);
ix::reportWebSocketTraffic(); ix::reportWebSocketTraffic();
} }
} }

8
test/run.sh Normal file
View File

@ -0,0 +1,8 @@
#!/bin/sh
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Debug .. || exit 1
make || exit 1
./ixwebsocket_unittest ${TEST}