/*
 *  IXWebSocketHandshake.cpp
 *  Author: Benjamin Sergeant
 *  Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
 */

#include "IXWebSocketHandshake.h"
#include "IXSocketConnect.h"
#include "IXUrlParser.h"

#include "libwshandshake.hpp"

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

namespace ix
{
    // Stores references to the caller-owned cancellation flag, deflate engine,
    // deflate options and deflate-enabled flag, plus a shared handle on the
    // socket used for the handshake. Nothing is connected or sent here.
    WebSocketHandshake::WebSocketHandshake(std::atomic<bool>& requestInitCancellation,
                                           std::shared_ptr<Socket> socket,
                                           WebSocketPerMessageDeflate& perMessageDeflate,
                                           WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
                                           std::atomic<bool>& enablePerMessageDeflate) :
        _requestInitCancellation(requestInitCancellation),
        _socket(socket),
        _perMessageDeflate(perMessageDeflate),
        _perMessageDeflateOptions(perMessageDeflateOptions),
        _enablePerMessageDeflate(enablePerMessageDeflate)
    {
    }

    // Returns a copy of `str` with every space, '\r' and '\n' removed.
    // Note that characters are removed anywhere in the string, not only at
    // its ends.
    std::string WebSocketHandshake::trim(const std::string& str)
    {
        std::string out(str);
        out.erase(std::remove_if(out.begin(),
                                 out.end(),
                                 [](char c) { return c == ' ' || c == '\r' || c == '\n'; }),
                  out.end());
        return out;
    }

    // Returns true when `a` and `b` have the same length and are equal
    // character by character once both are lowercased.
    bool WebSocketHandshake::insensitiveStringCompare(const std::string& a, const std::string& b)
    {
        return std::equal(a.begin(), a.end(),
                          b.begin(), b.end(),
                          [](char lhs, char rhs)
                          {
                              // std::tolower requires a value representable as
                              // unsigned char; passing a negative char is
                              // undefined behaviour.
                              return std::tolower(static_cast<unsigned char>(lhs)) ==
                                     std::tolower(static_cast<unsigned char>(rhs));
                          });
    }

    // Splits an HTTP request line on ' ' and returns (method, request URI,
    // HTTP version). Each returned token is passed through trim(); tokens that
    // are absent from the line come back as empty strings.
    std::tuple<std::string, std::string, std::string> WebSocketHandshake::parseRequestLine(
        const std::string& line)
    {
        // Request-Line = Method SP Request-URI SP HTTP-Version CRLF
        std::string token;
        std::stringstream tokenStream(line);
        std::vector<std::string> tokens;

        // Split by ' '
        while (std::getline(tokenStream, token, ' '))
        {
            tokens.push_back(token);
        }

        std::string method;
        if (tokens.size() >= 1)
        {
            method = trim(tokens[0]);
        }

        std::string requestUri;
        if (tokens.size() >= 2)
        {
            requestUri = trim(tokens[1]);
        }

        std::string httpVersion;
        if (tokens.size() >= 3)
        {
            httpVersion = trim(tokens[2]);
        }

        return std::make_tuple(method, requestUri, httpVersion);
    }

    // Returns a string of `len` characters drawn uniformly from a fixed
    // alphabet of digits and the letters A-H / a-h. A non-positive `len`
    // yields an empty string.
    std::string WebSocketHandshake::genRandomString(const int len)
    {
        // A negative length would be converted to a huge size_t by resize().
        if (len <= 0)
        {
            return std::string();
        }

        const std::string alphanum =
            "0123456789"
            "ABCDEFGH"
            "abcdefgh";

        std::random_device r;
        std::default_random_engine e1(r());
        std::uniform_int_distribution<int> dist(0, static_cast<int>(alphanum.size()) - 1);

        std::string s;
        s.resize(len);

        for (int i = 0; i < len; ++i)
        {
            s[i] = alphanum[dist(e1)];
        }

        return s;
    }

    // Writes a minimal HTTP/1.1 error response (status line only) to the socket
    // and returns a failed WebSocketInitResult carrying `code` and `reason`.
    // If the write itself fails, the result carries 500 and a timeout message.
    WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason)
    {
        std::stringstream ss;
        ss << "HTTP/1.1 ";
        ss << code;
        ss << " ";
        ss << reason;
        ss << "\r\n";

        // An HTTP message's header section must be terminated by an empty
        // line, even when no header fields are sent.
        ss << "\r\n";

        // Socket write can only be cancelled through a timeout here, not manually.
        static std::atomic<bool> requestInitCancellation(false);
        auto isCancellationRequested =
            makeCancellationRequestWithTimeout(1, requestInitCancellation);

        if (!_socket->writeBytes(ss.str(), isCancellationRequested))
        {
            return WebSocketInitResult(false, 500, "Timed out while sending error response");
        }

        return WebSocketInitResult(false, code, reason);
    }

    // Performs the client side of the WebSocket opening handshake:
    // connects to host:port, sends the upgrade GET request for `path`, then
    // validates the status line, the Connection header and the
    // Sec-WebSocket-Accept value of the reply. When per-message deflate is
    // enabled, the server's Sec-WebSocket-Extensions reply decides whether the
    // deflate engine is initialised or the feature is switched off.
    //
    // `url` is only used in error messages. `timeoutSecs` bounds the whole
    // exchange through the cancellation callback.
    //
    // Returns a WebSocketInitResult whose first field is true on success.
    WebSocketInitResult WebSocketHandshake::clientHandshake(const std::string& url,
                                                            const std::string& host,
                                                            const std::string& path,
                                                            int port,
                                                            int timeoutSecs)
    {
        _requestInitCancellation = false;

        auto isCancellationRequested =
            makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);

        std::string errMsg;
        bool success = _socket->connect(host, port, errMsg, isCancellationRequested);
        if (!success)
        {
            std::stringstream err;
            err << "Unable to connect to " << host
                << " on port " << port
                << ", error: " << errMsg;
            return WebSocketInitResult(false, 0, err.str());
        }

        //
        // Generate a random 24 bytes string which looks like it is base64 encoded
        // y3JJHMbDL1EzLkh9GBhXDw==
        // 0cb3Vd9HkbpVVumoS3Noka==
        //
        // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
        //
        std::string secWebSocketKey = genRandomString(22);
        secWebSocketKey += "==";

        std::stringstream ss;
        ss << "GET " << path << " HTTP/1.1\r\n";
        ss << "Host: "<< host << ":" << port << "\r\n";
        ss << "Upgrade: websocket\r\n";
        ss << "Connection: Upgrade\r\n";
        ss << "Sec-WebSocket-Version: 13\r\n";
        ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";

        if (_enablePerMessageDeflate)
        {
            ss << _perMessageDeflateOptions.generateHeader();
        }

        ss << "\r\n";

        if (!_socket->writeBytes(ss.str(), isCancellationRequested))
        {
            return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + url);
        }

        // Read HTTP status line
        auto lineResult = _socket->readLine(isCancellationRequested);
        auto lineValid = lineResult.first;
        const auto& line = lineResult.second;

        if (!lineValid)
        {
            return WebSocketInitResult(false, 0,
                                       std::string("Failed reading HTTP status line from ") + url);
        }

        // Validate status. Initialised because the status line may match
        // neither pattern below (e.g. "HTTP/2 200"), in which case sscanf
        // leaves the variable untouched and it is still reported.
        int status = 0;

        // HTTP/1.0 is too old.
        if (std::sscanf(line.c_str(), "HTTP/1.0 %d", &status) == 1)
        {
            std::stringstream err;
            err << "Server version is HTTP/1.0. Rejecting connection to " << host
                << ", status: " << status
                << ", HTTP Status line: " << line;
            return WebSocketInitResult(false, status, err.str());
        }

        // We want an 101 HTTP status
        if (std::sscanf(line.c_str(), "HTTP/1.1 %d", &status) != 1 || status != 101)
        {
            std::stringstream err;
            err << "Got bad status connecting to " << host
                << ", status: " << status
                << ", HTTP Status line: " << line;
            return WebSocketInitResult(false, status, err.str());
        }

        auto result = parseHttpHeaders(_socket, isCancellationRequested);
        auto headersValid = result.first;
        auto& headers = result.second;

        if (!headersValid)
        {
            return WebSocketInitResult(false, status, "Error parsing HTTP headers");
        }

        // Check the presence of the connection field
        if (headers.find("connection") == headers.end())
        {
            std::string errorMsg("Missing connection value");
            return WebSocketInitResult(false, status, errorMsg);
        }

        // Check the value of the connection field
        // Some websocket servers (Go/Gorilla?) send lowercase values for the
        // connection header, so do a case insensitive comparison
        if (!insensitiveStringCompare(headers["connection"], "Upgrade"))
        {
            std::stringstream err;
            err << "Invalid connection value: " << headers["connection"];
            return WebSocketInitResult(false, status, err.str());
        }

        // Zero-initialised so the buffer stays NUL-terminated as long as
        // generate() writes at most 28 characters.
        // NOTE(review): presumably a base64-encoded SHA-1 digest; confirm
        // against libwshandshake.hpp.
        char output[29] = {};
        WebSocketHandshakeKeyGen::generate(secWebSocketKey.c_str(), output);
        if (std::string(output) != headers["sec-websocket-accept"])
        {
            std::string errorMsg("Invalid Sec-WebSocket-Accept value");
            return WebSocketInitResult(false, status, errorMsg);
        }

        if (_enablePerMessageDeflate)
        {
            // Parse the server response. Does it support deflate ?
            std::string header = headers["sec-websocket-extensions"];
            WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

            // If the server does not support that extension, disable it.
            if (!webSocketPerMessageDeflateOptions.enabled())
            {
                _enablePerMessageDeflate = false;
            }
            // Otherwise try to initialize the deflate engine (zlib)
            else if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
            {
                return WebSocketInitResult(
                    false, 0,"Failed to initialize per message deflate engine");
            }
        }

        return WebSocketInitResult(true, status, "", headers, path);
    }

    // Performs the server side of the WebSocket opening handshake on an
    // already accepted connection: reads and validates the request line and
    // headers, then replies with "101 Switching Protocols" and the computed
    // Sec-WebSocket-Accept value. Per-message deflate is enabled when the
    // client's Sec-WebSocket-Extensions header asks for it.
    //
    // `fd` is configured through SocketConnect::configure and used in the
    // error message for a failed write. `timeoutSecs` bounds the exchange
    // through the cancellation callback.
    //
    // Validation failures are reported to the peer through
    // sendErrorResponse(400, ...). Returns a WebSocketInitResult whose first
    // field is true on success.
    WebSocketInitResult WebSocketHandshake::serverHandshake(int fd, int timeoutSecs)
    {
        _requestInitCancellation = false;

        // Set the socket to non blocking mode + other tweaks
        SocketConnect::configure(fd);

        auto isCancellationRequested =
            makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);

        std::string remote = std::string("remote fd ") + std::to_string(fd);

        // Read first line
        auto lineResult = _socket->readLine(isCancellationRequested);
        auto lineValid = lineResult.first;
        const auto& line = lineResult.second;

        if (!lineValid)
        {
            return sendErrorResponse(400, "Error reading HTTP request line");
        }

        // Validate request line (GET /foo HTTP/1.1\r\n)
        auto requestLine = parseRequestLine(line);
        auto method      = std::get<0>(requestLine);
        auto uri         = std::get<1>(requestLine);
        auto httpVersion = std::get<2>(requestLine);

        if (method != "GET")
        {
            return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method);
        }

        if (httpVersion != "HTTP/1.1")
        {
            return sendErrorResponse(400, "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
        }

        // Retrieve and validate HTTP headers
        auto result = parseHttpHeaders(_socket, isCancellationRequested);
        auto headersValid = result.first;
        auto& headers = result.second;

        if (!headersValid)
        {
            return sendErrorResponse(400, "Error parsing HTTP headers");
        }

        if (headers.find("sec-websocket-key") == headers.end())
        {
            return sendErrorResponse(400, "Missing Sec-WebSocket-Key value");
        }

        // The Upgrade token is case-insensitive (clients may send
        // "WebSocket"), so compare without regard to case, as the client side
        // does for the Connection header.
        if (!insensitiveStringCompare(headers["upgrade"], "websocket"))
        {
            return sendErrorResponse(400, "Invalid or missing Upgrade header");
        }

        if (headers.find("sec-websocket-version") == headers.end())
        {
            return sendErrorResponse(400, "Missing Sec-WebSocket-Version value");
        }

        {
            std::stringstream versionStream;
            versionStream << headers["sec-websocket-version"];
            int version = 0;
            versionStream >> version;

            if (version != 13)
            {
                return sendErrorResponse(400, "Invalid Sec-WebSocket-Version, "
                                              "need 13, got " + versionStream.str());
            }
        }

        // Zero-initialised so the buffer stays NUL-terminated as long as
        // generate() writes at most 28 characters.
        // NOTE(review): confirm the output size against libwshandshake.hpp.
        char output[29] = {};
        WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"].c_str(), output);

        std::stringstream ss;
        ss << "HTTP/1.1 101 Switching Protocols\r\n";
        ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n";
        ss << "Upgrade: websocket\r\n";
        ss << "Connection: Upgrade\r\n";

        // Parse the client headers. Does it support deflate ?
        std::string header = headers["sec-websocket-extensions"];
        WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

        // If the client has requested that extension, enable it.
        if (webSocketPerMessageDeflateOptions.enabled())
        {
            _enablePerMessageDeflate = true;

            if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
            {
                return WebSocketInitResult(
                    false, 0,"Failed to initialize per message deflate engine");
            }
            ss << webSocketPerMessageDeflateOptions.generateHeader();
        }

        ss << "\r\n";

        if (!_socket->writeBytes(ss.str(), isCancellationRequested))
        {
            return WebSocketInitResult(false, 0, std::string("Failed sending response to ") + remote);
        }

        // NOTE(review): the result reports 200 although the response written
        // above is a 101; kept as-is because callers may rely on this value.
        return WebSocketInitResult(true, 200, "", headers, uri);
    }
}