diff --git a/CMakeLists.txt b/CMakeLists.txt index 3274de27..13e22143 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ set( IXWEBSOCKET_SOURCES ixwebsocket/IXWebSocket.cpp ixwebsocket/IXWebSocketServer.cpp ixwebsocket/IXWebSocketTransport.cpp + ixwebsocket/IXWebSocketHandshake.cpp ixwebsocket/IXWebSocketPerMessageDeflate.cpp ixwebsocket/IXWebSocketPerMessageDeflateOptions.cpp ) @@ -34,6 +35,7 @@ set( IXWEBSOCKET_HEADERS ixwebsocket/IXWebSocket.h ixwebsocket/IXWebSocketServer.h ixwebsocket/IXWebSocketTransport.h + ixwebsocket/IXWebSocketHandshake.h ixwebsocket/IXWebSocketSendInfo.h ixwebsocket/IXWebSocketErrorInfo.h ixwebsocket/IXWebSocketPerMessageDeflate.h diff --git a/ixwebsocket/IXSocket.h b/ixwebsocket/IXSocket.h index 3454c4a1..7cdc4ea9 100644 --- a/ixwebsocket/IXSocket.h +++ b/ixwebsocket/IXSocket.h @@ -57,8 +57,5 @@ namespace ix std::atomic _sockfd; std::mutex _socketMutex; EventFd _eventfd; - - private: }; - } diff --git a/ixwebsocket/IXSocketConnect.cpp b/ixwebsocket/IXSocketConnect.cpp index ee974121..fc005e46 100644 --- a/ixwebsocket/IXSocketConnect.cpp +++ b/ixwebsocket/IXSocketConnect.cpp @@ -53,20 +53,19 @@ namespace ix // This is important so that we don't block the main UI thread when shutting down a connection which is // already trying to reconnect, and can be blocked waiting for ::connect to respond. // - bool SocketConnect::connectToAddress(const struct addrinfo *address, - int& sockfd, - std::string& errMsg, - const CancellationRequest& isCancellationRequested) + int SocketConnect::connectToAddress(const struct addrinfo *address, + std::string& errMsg, + const CancellationRequest& isCancellationRequested) { - sockfd = -1; - + errMsg = "no error"; + int fd = socket(address->ai_family, address->ai_socktype, address->ai_protocol); if (fd < 0) { errMsg = "Cannot create a socket"; - return false; + return -1; } // Set the socket to non blocking mode, so that slow responses cannot @@ -78,7 +77,7 @@ namespace ix { closeSocket(fd); errMsg = strerror(errno); - return false; + return -1; } // @@ -127,19 +126,18 @@ namespace ix { closeSocket(fd); errMsg = strerror(optval); - return false; + return -1; } else { // Success ! - sockfd = fd; - return true; + return fd; } } closeSocket(fd); errMsg = "connect timed out after 60 seconds"; - return false; + return -1; } int SocketConnect::connect(const std::string& hostname, @@ -161,14 +159,13 @@ namespace ix // iterate through the records to find a working peer struct addrinfo *address; - bool success = false; for (address = res; address != nullptr; address = address->ai_next) { // // Second try to connect to the remote host // - success = connectToAddress(address, sockfd, errMsg, isCancellationRequested); - if (success) + sockfd = connectToAddress(address, errMsg, isCancellationRequested); + if (sockfd != -1) { break; } diff --git a/ixwebsocket/IXSocketConnect.h b/ixwebsocket/IXSocketConnect.h index 4c7c7d4e..61de312d 100644 --- a/ixwebsocket/IXSocketConnect.h +++ b/ixwebsocket/IXSocketConnect.h @@ -24,10 +24,9 @@ namespace ix static void configure(int sockfd); private: - static bool connectToAddress(const struct addrinfo *address, - int& sockfd, - std::string& errMsg, - const CancellationRequest& isCancellationRequested); + static int connectToAddress(const struct addrinfo *address, + std::string& errMsg, + const CancellationRequest& isCancellationRequested); }; } diff --git a/ixwebsocket/IXWebSocket.cpp b/ixwebsocket/IXWebSocket.cpp index 12d9e98c..f5849e36 100644 --- a/ixwebsocket/IXWebSocket.cpp +++ b/ixwebsocket/IXWebSocket.cpp @@ -6,6 +6,7 @@ #include "IXWebSocket.h" #include "IXSetThreadName.h" +#include "IXWebSocketHandshake.h" #include #include diff --git a/ixwebsocket/IXWebSocketHandshake.cpp b/ixwebsocket/IXWebSocketHandshake.cpp new file mode 100644 index 00000000..980e7fe4 --- /dev/null +++ b/ixwebsocket/IXWebSocketHandshake.cpp @@ -0,0 +1,448 @@ +/* + * IXWebSocketHandshake.h + * Author: Benjamin Sergeant + * Copyright (c) 2019 Machine Zone, Inc. All rights reserved. + */ + +#include "IXWebSocketHandshake.h" +#include "IXSocketConnect.h" + +#include "libwshandshake.hpp" + +#include +#include +#include +#include +#include + + +namespace ix +{ + WebSocketHandshake::WebSocketHandshake(std::atomic& requestInitCancellation, + std::shared_ptr socket, + WebSocketPerMessageDeflate& perMessageDeflate, + WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, + std::atomic& enablePerMessageDeflate) : + _requestInitCancellation(requestInitCancellation), + _socket(socket), + _perMessageDeflate(perMessageDeflate), + _perMessageDeflateOptions(perMessageDeflateOptions), + _enablePerMessageDeflate(enablePerMessageDeflate) + { + + } + + bool WebSocketHandshake::parseUrl(const std::string& url, + std::string& protocol, + std::string& host, + std::string& path, + std::string& query, + int& port) + { + std::regex ex("(ws|wss)://([^/ :]+):?([^/ ]*)(/?[^ #?]*)\\x3f?([^ #]*)#?([^ ]*)"); + std::cmatch what; + if (!regex_match(url.c_str(), what, ex)) + { + return false; + } + + std::string portStr; + + protocol = std::string(what[1].first, what[1].second); + host = std::string(what[2].first, what[2].second); + portStr = std::string(what[3].first, what[3].second); + path = std::string(what[4].first, what[4].second); + query = std::string(what[5].first, what[5].second); + + if (portStr.empty()) + { + if (protocol == "ws") + { + port = 80; + } + else if (protocol == "wss") + { + port = 443; + } + else + { + // Invalid protocol. Should be caught by regex check + // but this missing branch trigger cpplint linter. + return false; + } + } + else + { + std::stringstream ss; + ss << portStr; + ss >> port; + } + + if (path.empty()) + { + path = "/"; + } + else if (path[0] != '/') + { + path = '/' + path; + } + + if (!query.empty()) + { + path += "?"; + path += query; + } + + return true; + } + + void WebSocketHandshake::printUrl(const std::string& url) + { + std::string protocol, host, path, query; + int port {0}; + + if (!WebSocketHandshake::parseUrl(url, protocol, host, + path, query, port)) + { + return; + } + + std::cout << "[" << url << "]" << std::endl; + std::cout << protocol << std::endl; + std::cout << host << std::endl; + std::cout << port << std::endl; + std::cout << path << std::endl; + std::cout << query << std::endl; + std::cout << "-------------------------------" << std::endl; + } + + std::string WebSocketHandshake::genRandomString(const int len) + { + std::string alphanum = + "0123456789" + "ABCDEFGH" + "abcdefgh"; + + std::random_device r; + std::default_random_engine e1(r()); + std::uniform_int_distribution dist(0, (int) alphanum.size() - 1); + + std::string s; + s.resize(len); + + for (int i = 0; i < len; ++i) + { + int x = dist(e1); + s[i] = alphanum[x]; + } + + return s; + } + + + std::pair WebSocketHandshake::parseHttpHeaders( + const CancellationRequest& isCancellationRequested) + { + WebSocketHttpHeaders headers; + + char line[256]; + int i; + + while (true) + { + int colon = 0; + + for (i = 0; + i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); + ++i) + { + if (!_socket->readByte(line+i, isCancellationRequested)) + { + return std::make_pair(false, headers); + } + + if (line[i] == ':' && colon == 0) + { + colon = i; + } + } + if (line[0] == '\r' && line[1] == '\n') + { + break; + } + + // line is a single header entry. split by ':', and add it to our + // header map. ignore lines with no colon. + if (colon > 0) + { + line[i] = '\0'; + std::string lineStr(line); + // colon is ':', colon+1 is ' ', colon+2 is the start of the value. + // i is end of string (\0), i-colon is length of string minus key; + // subtract 1 for '\0', 1 for '\n', 1 for '\r', + // 1 for the ' ' after the ':', and total is -4 + std::string name(lineStr.substr(0, colon)); + std::string value(lineStr.substr(colon + 2, i - colon - 4)); + + // Make the name lower case. + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + + headers[name] = value; + } + } + + return std::make_pair(true, headers); + } + + WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason) + { + std::stringstream ss; + ss << "HTTP/1.1 "; + ss << code; + ss << "\r\n"; + ss << reason; + ss << "\r\n"; + + // FIXME refactoring + auto start = std::chrono::system_clock::now(); + auto timeout = std::chrono::seconds(1); + + auto isCancellationRequested = [start, timeout]() -> bool + { + auto now = std::chrono::system_clock::now(); + if ((now - start) > timeout) return true; + + // No cancellation request + return false; + }; + + if (!_socket->writeBytes(ss.str(), isCancellationRequested)) + { + return WebSocketInitResult(false, 500, "Timed out while sending error response"); + } + + return WebSocketInitResult(false, code, reason); + } + + WebSocketInitResult WebSocketHandshake::clientHandshake(const std::string& url, + const std::string& host, + const std::string& path, + int port) + { + _requestInitCancellation = false; + + // FIXME: timeout should be configurable + auto start = std::chrono::system_clock::now(); + auto timeout = std::chrono::seconds(10); + + auto isCancellationRequested = [this, start, timeout]() -> bool + { + // Was an explicit cancellation requested ? + if (_requestInitCancellation) return true; + + auto now = std::chrono::system_clock::now(); + if ((now - start) > timeout) return true; + + // No cancellation request + return false; + }; + + std::string errMsg; + bool success = _socket->connect(host, port, errMsg, isCancellationRequested); + if (!success) + { + std::stringstream ss; + ss << "Unable to connect to " << host + << " on port " << port + << ", error: " << errMsg; + return WebSocketInitResult(false, 0, ss.str()); + } + + // + // Generate a random 24 bytes string which looks like it is base64 encoded + // y3JJHMbDL1EzLkh9GBhXDw== + // 0cb3Vd9HkbpVVumoS3Noka== + // + // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for + // + std::string secWebSocketKey = genRandomString(22); + secWebSocketKey += "=="; + + std::stringstream ss; + ss << "GET " << path << " HTTP/1.1\r\n"; + ss << "Host: "<< host << ":" << port << "\r\n"; + ss << "Upgrade: websocket\r\n"; + ss << "Connection: Upgrade\r\n"; + ss << "Sec-WebSocket-Version: 13\r\n"; + ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n"; + + if (_enablePerMessageDeflate) + { + ss << _perMessageDeflateOptions.generateHeader(); + } + + ss << "\r\n"; + + if (!_socket->writeBytes(ss.str(), isCancellationRequested)) + { + return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + url); + } + + // Read HTTP status line + auto lineResult = _socket->readLine(isCancellationRequested); + auto lineValid = lineResult.first; + auto line = lineResult.second; + + if (!lineValid) + { + return WebSocketInitResult(false, 0, + std::string("Failed reading HTTP status line from ") + url); + } + + // Validate status + int status; + + // HTTP/1.0 is too old. + if (sscanf(line.c_str(), "HTTP/1.0 %d", &status) == 1) + { + std::stringstream ss; + ss << "Server version is HTTP/1.0. Rejecting connection to " << host + << ", status: " << status + << ", HTTP Status line: " << line; + return WebSocketInitResult(false, status, ss.str()); + } + + // We want an 101 HTTP status + if (sscanf(line.c_str(), "HTTP/1.1 %d", &status) != 1 || status != 101) + { + std::stringstream ss; + ss << "Got bad status connecting to " << host + << ", status: " << status + << ", HTTP Status line: " << line; + return WebSocketInitResult(false, status, ss.str()); + } + + auto result = parseHttpHeaders(isCancellationRequested); + auto headersValid = result.first; + auto headers = result.second; + + if (!headersValid) + { + return WebSocketInitResult(false, status, "Error parsing HTTP headers"); + } + + char output[29] = {}; + WebSocketHandshakeKeyGen::generate(secWebSocketKey.c_str(), output); + if (std::string(output) != headers["sec-websocket-accept"]) + { + std::string errorMsg("Invalid Sec-WebSocket-Accept value"); + return WebSocketInitResult(false, status, errorMsg); + } + + if (_enablePerMessageDeflate) + { + // Parse the server response. Does it support deflate ? + std::string header = headers["sec-websocket-extensions"]; + WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); + + // If the server does not support that extension, disable it. + if (!webSocketPerMessageDeflateOptions.enabled()) + { + _enablePerMessageDeflate = false; + } + // Otherwise try to initialize the deflate engine (zlib) + else if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions)) + { + return WebSocketInitResult( + false, 0,"Failed to initialize per message deflate engine"); + } + } + + return WebSocketInitResult(true, status, "", headers); + } + + WebSocketInitResult WebSocketHandshake::serverHandshake(int fd) + { + _requestInitCancellation = false; + + // Set the socket to non blocking mode + other tweaks + SocketConnect::configure(fd); + + // FIXME: timeout should be configurable + auto start = std::chrono::system_clock::now(); + auto timeout = std::chrono::seconds(3); + + auto isCancellationRequested = [this, start, timeout]() -> bool + { + // Was an explicit cancellation requested ? + if (_requestInitCancellation) return true; + + auto now = std::chrono::system_clock::now(); + if ((now - start) > timeout) return true; + + // No cancellation request + return false; + }; + + std::string remote = std::string("remote fd ") + std::to_string(fd); + + // Read first line + auto lineResult = _socket->readLine(isCancellationRequested); + auto lineValid = lineResult.first; + auto line = lineResult.second; + + if (!lineValid) + { + return sendErrorResponse(400, "Error reading HTTP request line"); + } + + // FIXME: Validate line content (GET /) + + auto result = parseHttpHeaders(isCancellationRequested); + auto headersValid = result.first; + auto headers = result.second; + + if (!headersValid) + { + return sendErrorResponse(400, "Error parsing HTTP headers"); + } + + if (headers.find("sec-websocket-key") == headers.end()) + { + return sendErrorResponse(400, "Missing Sec-WebSocket-Key value"); + } + + char output[29] = {}; + WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"].c_str(), output); + + std::stringstream ss; + ss << "HTTP/1.1 101\r\n"; + ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n"; + + // Parse the client headers. Does it support deflate ? + std::string header = headers["sec-websocket-extensions"]; + WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); + + // If the client has requested that extension, enable it. + if (webSocketPerMessageDeflateOptions.enabled()) + { + _enablePerMessageDeflate = true; + + if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions)) + { + return WebSocketInitResult( + false, 0,"Failed to initialize per message deflate engine"); + } + ss << webSocketPerMessageDeflateOptions.generateHeader(); + } + + ss << "\r\n"; + + if (!_socket->writeBytes(ss.str(), isCancellationRequested)) + { + return WebSocketInitResult(false, 0, std::string("Failed sending response to ") + remote); + } + + return WebSocketInitResult(true, 200, "", headers); + } +} diff --git a/ixwebsocket/IXWebSocketHandshake.h b/ixwebsocket/IXWebSocketHandshake.h new file mode 100644 index 00000000..46116928 --- /dev/null +++ b/ixwebsocket/IXWebSocketHandshake.h @@ -0,0 +1,77 @@ +/* + * IXWebSocketHandshake.h + * Author: Benjamin Sergeant + * Copyright (c) 2019 Machine Zone, Inc. All rights reserved. + */ + +#pragma once + +#include "IXCancellationRequest.h" +#include "IXWebSocketHttpHeaders.h" +#include "IXWebSocketPerMessageDeflate.h" +#include "IXWebSocketPerMessageDeflateOptions.h" +#include "IXSocket.h" + +#include +#include +#include +#include + +namespace ix +{ + struct WebSocketInitResult + { + bool success; + int http_status; + std::string errorStr; + WebSocketHttpHeaders headers; + + WebSocketInitResult(bool s = false, + int status = 0, + const std::string& e = std::string(), + WebSocketHttpHeaders h = WebSocketHttpHeaders()) + { + success = s; + http_status = status; + errorStr = e; + headers = h; + } + }; + + class WebSocketHandshake { + public: + WebSocketHandshake(std::atomic& requestInitCancellation, + std::shared_ptr _socket, + WebSocketPerMessageDeflate& perMessageDeflate, + WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, + std::atomic& enablePerMessageDeflate); + + WebSocketInitResult clientHandshake(const std::string& url, + const std::string& host, + const std::string& path, + int port); + WebSocketInitResult serverHandshake(int fd); + + static bool parseUrl(const std::string& url, + std::string& protocol, + std::string& host, + std::string& path, + std::string& query, + int& port); + + private: + static void printUrl(const std::string& url); + std::string genRandomString(const int len); + + // Parse HTTP headers + std::pair parseHttpHeaders(const CancellationRequest& isCancellationRequested); + WebSocketInitResult sendErrorResponse(int code, const std::string& reason); + + std::atomic& _requestInitCancellation; + std::shared_ptr _socket; + WebSocketPerMessageDeflate& _perMessageDeflate; + WebSocketPerMessageDeflateOptions& _perMessageDeflateOptions; + std::atomic& _enablePerMessageDeflate; + }; +} + diff --git a/ixwebsocket/IXWebSocketTransport.cpp b/ixwebsocket/IXWebSocketTransport.cpp index ff0d43a9..4d439a1f 100644 --- a/ixwebsocket/IXWebSocketTransport.cpp +++ b/ixwebsocket/IXWebSocketTransport.cpp @@ -9,10 +9,9 @@ // #include "IXWebSocketTransport.h" +#include "IXWebSocketHandshake.h" #include "IXWebSocketHttpHeaders.h" -#include "IXSocketConnect.h" -#include "IXSocket.h" #ifdef IXWEBSOCKET_USE_TLS # ifdef __APPLE__ # include "IXSocketAppleSSL.h" @@ -21,8 +20,6 @@ # endif #endif -#include "libwshandshake.hpp" - #include #include @@ -32,9 +29,6 @@ #include #include #include -#include -#include -#include namespace ix @@ -60,175 +54,13 @@ namespace ix _enablePerMessageDeflate = _perMessageDeflateOptions.enabled(); } - bool WebSocketTransport::parseUrl(const std::string& url, - std::string& protocol, - std::string& host, - std::string& path, - std::string& query, - int& port) - { - std::regex ex("(ws|wss)://([^/ :]+):?([^/ ]*)(/?[^ #?]*)\\x3f?([^ #]*)#?([^ ]*)"); - std::cmatch what; - if (!regex_match(url.c_str(), what, ex)) - { - return false; - } - - std::string portStr; - - protocol = std::string(what[1].first, what[1].second); - host = std::string(what[2].first, what[2].second); - portStr = std::string(what[3].first, what[3].second); - path = std::string(what[4].first, what[4].second); - query = std::string(what[5].first, what[5].second); - - if (portStr.empty()) - { - if (protocol == "ws") - { - port = 80; - } - else if (protocol == "wss") - { - port = 443; - } - else - { - // Invalid protocol. Should be caught by regex check - // but this missing branch trigger cpplint linter. - return false; - } - } - else - { - std::stringstream ss; - ss << portStr; - ss >> port; - } - - if (path.empty()) - { - path = "/"; - } - else if (path[0] != '/') - { - path = '/' + path; - } - - if (!query.empty()) - { - path += "?"; - path += query; - } - - return true; - } - - void WebSocketTransport::printUrl(const std::string& url) - { - std::string protocol, host, path, query; - int port {0}; - - if (!WebSocketTransport::parseUrl(url, protocol, host, - path, query, port)) - { - return; - } - - std::cout << "[" << url << "]" << std::endl; - std::cout << protocol << std::endl; - std::cout << host << std::endl; - std::cout << port << std::endl; - std::cout << path << std::endl; - std::cout << query << std::endl; - std::cout << "-------------------------------" << std::endl; - } - - std::pair WebSocketTransport::parseHttpHeaders(const CancellationRequest& isCancellationRequested) - { - WebSocketHttpHeaders headers; - - char line[256]; - int i; - - while (true) - { - int colon = 0; - - for (i = 0; - i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); - ++i) - { - if (!_socket->readByte(line+i, isCancellationRequested)) - { - return std::make_pair(false, headers); - } - - if (line[i] == ':' && colon == 0) - { - colon = i; - } - } - if (line[0] == '\r' && line[1] == '\n') - { - break; - } - - // line is a single header entry. split by ':', and add it to our - // header map. ignore lines with no colon. - if (colon > 0) - { - line[i] = '\0'; - std::string lineStr(line); - // colon is ':', colon+1 is ' ', colon+2 is the start of the value. - // i is end of string (\0), i-colon is length of string minus key; - // subtract 1 for '\0', 1 for '\n', 1 for '\r', - // 1 for the ' ' after the ':', and total is -4 - std::string name(lineStr.substr(0, colon)); - std::string value(lineStr.substr(colon + 2, i - colon - 4)); - - // Make the name lower case. - std::transform(name.begin(), name.end(), name.begin(), ::tolower); - - headers[name] = value; - } - } - - return std::make_pair(true, headers); - } - - std::string WebSocketTransport::genRandomString(const int len) - { - std::string alphanum = - "0123456789" - "ABCDEFGH" - "abcdefgh"; - - std::random_device r; - std::default_random_engine e1(r()); - std::uniform_int_distribution dist(0, (int) alphanum.size() - 1); - - std::string s; - s.resize(len); - - for (int i = 0; i < len; ++i) - { - int x = dist(e1); - s[i] = alphanum[x]; - } - - return s; - } - // Client WebSocketInitResult WebSocketTransport::connectToUrl(const std::string& url) { std::string protocol, host, path, query; int port; - _requestInitCancellation = false; - - if (!WebSocketTransport::parseUrl(url, protocol, host, + if (!WebSocketHandshake::parseUrl(url, protocol, host, path, query, port)) { return WebSocketInitResult(false, 0, @@ -254,243 +86,38 @@ namespace ix _socket = std::make_shared(); } - // FIXME: timeout should be configurable - auto start = std::chrono::system_clock::now(); - auto timeout = std::chrono::seconds(10); + WebSocketHandshake webSocketHandshake(_requestInitCancellation, + _socket, + _perMessageDeflate, + _perMessageDeflateOptions, + _enablePerMessageDeflate); - auto isCancellationRequested = [this, start, timeout]() -> bool + auto result = webSocketHandshake.clientHandshake(url, host, path, port); + if (result.success) { - // Was an explicit cancellation requested ? - if (_requestInitCancellation) return true; - - auto now = std::chrono::system_clock::now(); - if ((now - start) > timeout) return true; - - // No cancellation request - return false; - }; - - std::string errMsg; - bool success = _socket->connect(host, port, errMsg, isCancellationRequested); - if (!success) - { - std::stringstream ss; - ss << "Unable to connect to " << host - << " on port " << port - << ", error: " << errMsg; - return WebSocketInitResult(false, 0, ss.str()); + setReadyState(OPEN); } - - // - // Generate a random 24 bytes string which looks like it is base64 encoded - // y3JJHMbDL1EzLkh9GBhXDw== - // 0cb3Vd9HkbpVVumoS3Noka== - // - // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for - // - std::string secWebSocketKey = genRandomString(22); - secWebSocketKey += "=="; - - std::stringstream ss; - ss << "GET " << path << " HTTP/1.1\r\n"; - ss << "Host: "<< host << ":" << port << "\r\n"; - ss << "Upgrade: websocket\r\n"; - ss << "Connection: Upgrade\r\n"; - ss << "Sec-WebSocket-Version: 13\r\n"; - ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n"; - - if (_enablePerMessageDeflate) - { - ss << _perMessageDeflateOptions.generateHeader(); - } - - ss << "\r\n"; - - if (!_socket->writeBytes(ss.str(), isCancellationRequested)) - { - return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + url); - } - - // Read HTTP status line - auto lineResult = _socket->readLine(isCancellationRequested); - auto lineValid = lineResult.first; - auto line = lineResult.second; - - if (!lineValid) - { - return WebSocketInitResult(false, 0, - std::string("Failed reading HTTP status line from ") + url); - } - - // Validate status - int status; - - // HTTP/1.0 is too old. - if (sscanf(line.c_str(), "HTTP/1.0 %d", &status) == 1) - { - std::stringstream ss; - ss << "Server version is HTTP/1.0. Rejecting connection to " << host - << ", status: " << status - << ", HTTP Status line: " << line; - return WebSocketInitResult(false, status, ss.str()); - } - - // We want an 101 HTTP status - if (sscanf(line.c_str(), "HTTP/1.1 %d", &status) != 1 || status != 101) - { - std::stringstream ss; - ss << "Got bad status connecting to " << host - << ", status: " << status - << ", HTTP Status line: " << line; - return WebSocketInitResult(false, status, ss.str()); - } - - auto result = parseHttpHeaders(isCancellationRequested); - auto headersValid = result.first; - auto headers = result.second; - - if (!headersValid) - { - return WebSocketInitResult(false, status, "Error parsing HTTP headers"); - } - - char output[29] = {}; - WebSocketHandshake::generate(secWebSocketKey.c_str(), output); - if (std::string(output) != headers["sec-websocket-accept"]) - { - std::string errorMsg("Invalid Sec-WebSocket-Accept value"); - return WebSocketInitResult(false, status, errorMsg); - } - - if (_enablePerMessageDeflate) - { - // Parse the server response. Does it support deflate ? - std::string header = headers["sec-websocket-extensions"]; - WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); - - // If the server does not support that extension, disable it. - if (!webSocketPerMessageDeflateOptions.enabled()) - { - _enablePerMessageDeflate = false; - } - // Otherwise try to initialize the deflate engine (zlib) - else if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions)) - { - return WebSocketInitResult( - false, 0,"Failed to initialize per message deflate engine"); - } - } - - setReadyState(OPEN); - - return WebSocketInitResult(true, status, "", headers); - } - - WebSocketInitResult WebSocketTransport::sendErrorResponse(int code, std::string reason) - { - std::stringstream ss; - ss << "HTTP/1.1 "; - ss << code; - ss << "\r\n"; - ss << reason; - ss << "\r\n"; - - auto isCancellationRequested = [this]() -> bool - { - return _requestInitCancellation; - }; - - if (!_socket->writeBytes(ss.str(), isCancellationRequested)) - { - return WebSocketInitResult(false, 500, "Failed sending response"); - } - - return WebSocketInitResult(false, code, reason); + return result; } // Server WebSocketInitResult WebSocketTransport::connectToSocket(int fd) { - _requestInitCancellation = false; - - // Set the socket to non blocking mode + other tweaks - SocketConnect::configure(fd); - _socket.reset(); _socket = std::make_shared(fd); - // FIXME: timeout should be configurable - auto start = std::chrono::system_clock::now(); - auto timeout = std::chrono::seconds(3); + WebSocketHandshake webSocketHandshake(_requestInitCancellation, + _socket, + _perMessageDeflate, + _perMessageDeflateOptions, + _enablePerMessageDeflate); - auto isCancellationRequested = [this, start, timeout]() -> bool + auto result = webSocketHandshake.serverHandshake(fd); + if (result.success) { - // Was an explicit cancellation requested ? - if (_requestInitCancellation) return true; - - auto now = std::chrono::system_clock::now(); - if ((now - start) > timeout) return true; - - // No cancellation request - return false; - }; - - std::string remote = std::string("remote fd ") + std::to_string(fd); - - // Read first line - auto lineResult = _socket->readLine(isCancellationRequested); - auto lineValid = lineResult.first; - auto line = lineResult.second; - - // FIXME: Validate line content (GET /) - - auto result = parseHttpHeaders(isCancellationRequested); - auto headersValid = result.first; - auto headers = result.second; - - if (!headersValid) - { - return sendErrorResponse(400, "Error parsing HTTP headers"); + setReadyState(OPEN); } - - if (headers.find("sec-websocket-key") == headers.end()) - { - return sendErrorResponse(400, "Missing Sec-WebSocket-Key value"); - } - - char output[29] = {}; - WebSocketHandshake::generate(headers["sec-websocket-key"].c_str(), output); - - std::stringstream ss; - ss << "HTTP/1.1 101\r\n"; - ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n"; - - // Parse the client headers. Does it support deflate ? - std::string header = headers["sec-websocket-extensions"]; - WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); - - // If the client has requested that extension, enable it. - if (webSocketPerMessageDeflateOptions.enabled()) - { - _enablePerMessageDeflate = true; - - if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions)) - { - return WebSocketInitResult( - false, 0,"Failed to initialize per message deflate engine"); - } - ss << webSocketPerMessageDeflateOptions.generateHeader(); - } - - ss << "\r\n"; - - if (!_socket->writeBytes(ss.str(), isCancellationRequested)) - { - return WebSocketInitResult(false, 0, std::string("Failed sending response to ") + remote); - } - - setReadyState(OPEN); - return WebSocketInitResult(true, 200, "", headers); + return result; } WebSocketTransport::ReadyStateValues WebSocketTransport::getReadyState() const diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index a40c0aec..966f2402 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -22,30 +22,12 @@ #include "IXWebSocketPerMessageDeflateOptions.h" #include "IXWebSocketHttpHeaders.h" #include "IXCancellationRequest.h" +#include "IXWebSocketHandshake.h" namespace ix { class Socket; - struct WebSocketInitResult - { - bool success; - int http_status; - std::string errorStr; - WebSocketHttpHeaders headers; - - WebSocketInitResult(bool s = false, - int status = 0, - const std::string& e = std::string(), - WebSocketHttpHeaders h = WebSocketHttpHeaders()) - { - success = s; - http_status = status; - errorStr = e; - headers = h; - } - }; - class WebSocketTransport { public: @@ -89,14 +71,6 @@ namespace ix void setOnCloseCallback(const OnCloseCallback& onCloseCallback); void dispatch(const OnMessageCallback& onMessageCallback); - static void printUrl(const std::string& url); - static bool parseUrl(const std::string& url, - std::string& protocol, - std::string& host, - std::string& path, - std::string& query, - int& port); - private: std::string _url; std::string _origin; @@ -161,11 +135,5 @@ namespace ix unsigned getRandomUnsigned(); void unmaskReceiveBuffer(const wsheader_type& ws); - std::string genRandomString(const int len); - - // Parse HTTP headers - std::pair parseHttpHeaders(const CancellationRequest& isCancellationRequested); - - WebSocketInitResult sendErrorResponse(int code, std::string reason); }; } diff --git a/ixwebsocket/libwshandshake.hpp b/ixwebsocket/libwshandshake.hpp index 5b90dcc8..588e1b6e 100644 --- a/ixwebsocket/libwshandshake.hpp +++ b/ixwebsocket/libwshandshake.hpp @@ -21,7 +21,7 @@ #include #include -class WebSocketHandshake { +class WebSocketHandshakeKeyGen { template struct static_for { void operator()(uint32_t *a, uint32_t *b) { diff --git a/test/cmd_websocket_chat.cpp b/test/cmd_websocket_chat.cpp index b78da6b6..900cbf82 100644 --- a/test/cmd_websocket_chat.cpp +++ b/test/cmd_websocket_chat.cpp @@ -226,7 +226,7 @@ namespace } } -TEST_CASE("Websocket chat", "[websocket_chat]") +TEST_CASE("Websocket_chat", "[websocket_chat]") { SECTION("Exchange and count sent/received messages.") { diff --git a/test/run.sh b/test/run.sh index 2b0cd1f2..4367ba44 100644 --- a/test/run.sh +++ b/test/run.sh @@ -2,7 +2,7 @@ mkdir build cd build -cmake .. || exit 1 +cmake -DCMAKE_BUILD_TYPE=Debug .. || exit 1 make || exit 1 ./ixwebsocket_unittest ${TEST}