diff --git a/ixwebsocket/IXWebSocketHandshake.cpp b/ixwebsocket/IXWebSocketHandshake.cpp index 1bcc26e0..ad9ebfbf 100644 --- a/ixwebsocket/IXWebSocketHandshake.cpp +++ b/ixwebsocket/IXWebSocketHandshake.cpp @@ -116,6 +116,49 @@ namespace ix std::cout << "-------------------------------" << std::endl; } + std::string trim(const std::string& str) + { + std::string out(str); + out.erase(std::remove(out.begin(), out.end(), ' '), out.end()); + out.erase(std::remove(out.begin(), out.end(), '\r'), out.end()); + out.erase(std::remove(out.begin(), out.end(), '\n'), out.end()); + return out; + } + + std::tuple WebSocketHandshake::parseRequestLine(const std::string& line) + { + // Request-Line = Method SP Request-URI SP HTTP-Version CRLF + std::string token; + std::stringstream tokenStream(line); + std::vector tokens; + + // Split by ' ' + while (std::getline(tokenStream, token, ' ')) + { + tokens.push_back(token); + } + + std::string method; + if (tokens.size() >= 1) + { + method = trim(tokens[0]); + } + + std::string requestUri; + if (tokens.size() >= 2) + { + requestUri = trim(tokens[1]); + } + + std::string httpVersion; + if (tokens.size() >= 3) + { + httpVersion = trim(tokens[2]); + } + + return std::make_tuple(method, requestUri, httpVersion); + } + std::string WebSocketHandshake::genRandomString(const int len) { std::string alphanum = @@ -363,8 +406,23 @@ namespace ix return sendErrorResponse(400, "Error reading HTTP request line"); } - // FIXME: Validate line content (GET /) + // Validate request line (GET /foo HTTP/1.1\r\n) + auto requestLine = parseRequestLine(line); + auto method = std::get<0>(requestLine); + auto uri = std::get<1>(requestLine); + auto httpVersion = std::get<2>(requestLine); + if (method != "GET") + { + return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method); + } + + if (httpVersion != "HTTP/1.1") + { + return sendErrorResponse(400, "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion); + } + + // Retrieve and validate HTTP headers auto result = parseHttpHeaders(isCancellationRequested); auto headersValid = result.first; auto headers = result.second; @@ -379,6 +437,29 @@ namespace ix return sendErrorResponse(400, "Missing Sec-WebSocket-Key value"); } + if (headers["upgrade"] != "websocket") + { + return sendErrorResponse(400, "Invalid or missing Upgrade header"); + } + + if (headers.find("sec-websocket-version") == headers.end()) + { + return sendErrorResponse(400, "Missing Sec-WebSocket-Version value"); + } + + { + std::stringstream ss; + ss << headers["sec-websocket-version"]; + int version; + ss >> version; + + if (version != 13) + { + return sendErrorResponse(400, "Invalid Sec-WebSocket-Version, " + "need 13, got" + ss.str()); + } + } + char output[29] = {}; WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"].c_str(), output); diff --git a/ixwebsocket/IXWebSocketHandshake.h b/ixwebsocket/IXWebSocketHandshake.h index 46116928..2d9054c3 100644 --- a/ixwebsocket/IXWebSocketHandshake.h +++ b/ixwebsocket/IXWebSocketHandshake.h @@ -16,6 +16,7 @@ #include #include #include +#include namespace ix { @@ -67,6 +68,8 @@ namespace ix std::pair parseHttpHeaders(const CancellationRequest& isCancellationRequested); WebSocketInitResult sendErrorResponse(int code, const std::string& reason); + std::tuple parseRequestLine(const std::string& line); + std::atomic& _requestInitCancellation; std::shared_ptr _socket; WebSocketPerMessageDeflate& _perMessageDeflate; @@ -74,4 +77,3 @@ namespace ix std::atomic& _enablePerMessageDeflate; }; } - diff --git a/test/IXWebSocketServerTest.cpp b/test/IXWebSocketServerTest.cpp index b72d6c76..ab783c85 100644 --- a/test/IXWebSocketServerTest.cpp +++ b/test/IXWebSocketServerTest.cpp @@ -155,7 +155,13 @@ TEST_CASE("Websocket_server", "[websocket_server]") bool success = socket.connect(host, port, errMsg, isCancellationRequested); REQUIRE(success); - socket.writeBytes("GET /\r\nSec-WebSocket-Key: foobar\r\n\r\n", isCancellationRequested); + socket.writeBytes("GET / HTTP/1.1\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: foobar\r\n" + "\r\n", + isCancellationRequested); + auto lineResult = socket.readLine(isCancellationRequested); auto lineValid = lineResult.first; auto line = lineResult.second;