2019-01-03 05:07:54 +01:00
|
|
|
/*
|
|
|
|
 *  IXWebSocketHandshake.cpp
|
|
|
|
* Author: Benjamin Sergeant
|
|
|
|
* Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "IXWebSocketHandshake.h"
|
|
|
|
#include "IXSocketConnect.h"
|
2019-03-01 06:54:03 +01:00
|
|
|
#include "IXUrlParser.h"
|
2019-01-03 05:07:54 +01:00
|
|
|
|
|
|
|
#include "libwshandshake.hpp"
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
|
|
|
#include <random>
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
2019-02-21 03:59:07 +01:00
|
|
|
namespace ix
|
2019-01-03 05:07:54 +01:00
|
|
|
{
|
|
|
|
//
// Build a handshake helper from state supplied by the caller.
//
// requestInitCancellation  : flag reset at the start of each handshake and
//                            handed to makeCancellationRequestWithTimeout().
// socket                   : socket used for connect / read / write.
// perMessageDeflate        : deflate engine initialized once the extension
//                            has been negotiated.
// perMessageDeflateOptions : options used to build the extension header
//                            sent by the client.
// enablePerMessageDeflate  : flag read and updated by the handshakes.
//
// The constructor only initializes members; no I/O happens here.
//
WebSocketHandshake::WebSocketHandshake(
    std::atomic<bool>& requestInitCancellation,
    std::shared_ptr<Socket> socket,
    WebSocketPerMessageDeflate& perMessageDeflate,
    WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
    std::atomic<bool>& enablePerMessageDeflate)
    : _requestInitCancellation(requestInitCancellation)
    , _socket(socket)
    , _perMessageDeflate(perMessageDeflate)
    , _perMessageDeflateOptions(perMessageDeflateOptions)
    , _enablePerMessageDeflate(enablePerMessageDeflate)
{
    // Intentionally empty: everything is done in the initializer list.
}
|
|
|
|
|
2019-01-06 23:26:11 +01:00
|
|
|
//
// Return a copy of str with every space, carriage return and line feed
// removed.
//
// Despite the name, characters are stripped anywhere in the string, not
// only at its ends ("a b\r\n" becomes "ab").
//
std::string WebSocketHandshake::trim(const std::string& str)
{
    std::string kept;
    kept.reserve(str.size());

    for (char c : str)
    {
        const bool dropped = (c == ' ' || c == '\r' || c == '\n');
        if (!dropped)
        {
            kept.push_back(c);
        }
    }

    return kept;
}
|
|
|
|
|
2019-01-26 01:17:47 +01:00
|
|
|
//
// Case insensitive equality test between two strings.
//
// Returns true when a and b have the same length and every pair of
// characters is equal once both are lowercased with tolower().
//
bool WebSocketHandshake::insensitiveStringCompare(const std::string& a, const std::string& b)
{
    return std::equal(a.begin(), a.end(),
                      b.begin(), b.end(),
                      [](char lhs, char rhs)
                      {
                          // tolower() is only defined for values representable
                          // as unsigned char (or EOF). Passing a negative plain
                          // char (any byte >= 0x80 where char is signed) is
                          // undefined behavior, so convert first.
                          return tolower(static_cast<unsigned char>(lhs)) ==
                                 tolower(static_cast<unsigned char>(rhs));
                      });
}
|
|
|
|
|
2019-01-03 22:41:06 +01:00
|
|
|
std::tuple<std::string, std::string, std::string> WebSocketHandshake::parseRequestLine(const std::string& line)
|
|
|
|
{
|
|
|
|
// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
|
|
|
|
std::string token;
|
|
|
|
std::stringstream tokenStream(line);
|
|
|
|
std::vector<std::string> tokens;
|
|
|
|
|
|
|
|
// Split by ' '
|
|
|
|
while (std::getline(tokenStream, token, ' '))
|
|
|
|
{
|
|
|
|
tokens.push_back(token);
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string method;
|
|
|
|
if (tokens.size() >= 1)
|
|
|
|
{
|
|
|
|
method = trim(tokens[0]);
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string requestUri;
|
|
|
|
if (tokens.size() >= 2)
|
|
|
|
{
|
|
|
|
requestUri = trim(tokens[1]);
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string httpVersion;
|
|
|
|
if (tokens.size() >= 3)
|
|
|
|
{
|
|
|
|
httpVersion = trim(tokens[2]);
|
|
|
|
}
|
|
|
|
|
|
|
|
return std::make_tuple(method, requestUri, httpVersion);
|
|
|
|
}
|
|
|
|
|
2019-01-03 05:07:54 +01:00
|
|
|
//
// Generate a random string of len characters.
//
// Characters are drawn uniformly from the digits, 'A'..'H' and 'a'..'h'.
// The engine is std::default_random_engine seeded from std::random_device
// on every call.
//
// Returns an empty string when len is zero or negative. A negative length
// used to be converted to a huge unsigned size by std::string::resize,
// which throws instead of returning.
//
// NOTE(review): the alphabet stops at 'H'/'h' rather than covering the full
// base64 alphabet; it is kept as is here - confirm whether that is intended.
//
std::string WebSocketHandshake::genRandomString(const int len)
{
    if (len <= 0)
    {
        return std::string();
    }

    const std::string alphanum =
        "0123456789"
        "ABCDEFGH"
        "abcdefgh";

    std::random_device r;
    std::default_random_engine e1(r());
    std::uniform_int_distribution<int> dist(0, static_cast<int>(alphanum.size()) - 1);

    std::string s(static_cast<std::string::size_type>(len), '\0');
    for (char& c : s)
    {
        c = alphanum[dist(e1)];
    }

    return s;
}
|
|
|
|
|
|
|
|
//
// Write a minimal HTTP error response on the socket and build the matching
// failed WebSocketInitResult.
//
// code   : HTTP status code placed in the status line and in the result.
// reason : reason phrase placed in the status line and in the result.
//
// Returns WebSocketInitResult(false, code, reason), or
// WebSocketInitResult(false, 500, ...) when the write itself fails.
//
WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason)
{
    std::stringstream ss;
    ss << "HTTP/1.1 " << code << " " << reason << "\r\n";

    // An HTTP message is a start line, zero or more header lines, then an
    // empty line. Without this final CRLF the header section is never
    // terminated and the peer cannot tell that the response is complete.
    ss << "\r\n";

    // Socket write can only be cancelled through a timeout here, not manually.
    // The flag is therefore never set; only the 1 second timeout applies.
    static std::atomic<bool> requestInitCancellation(false);
    auto isCancellationRequested =
        makeCancellationRequestWithTimeout(1, requestInitCancellation);

    if (!_socket->writeBytes(ss.str(), isCancellationRequested))
    {
        return WebSocketInitResult(false, 500, "Timed out while sending error response");
    }

    return WebSocketInitResult(false, code, reason);
}
|
|
|
|
|
|
|
|
//
// Client side of the WebSocket opening handshake.
//
// Steps:
//   1. connect the socket to host:port
//   2. send the HTTP GET upgrade request for path
//   3. read the status line and require "HTTP/1.1 101"
//   4. read the response headers and validate Connection and
//      Sec-WebSocket-Accept
//   5. if per message deflate was requested, enable or disable it based on
//      the server answer
//
// url         : only used to build error messages.
// host, port  : where to connect, also sent in the Host header.
// path        : request target of the GET line, returned in the result.
// timeoutSecs : timeout given to makeCancellationRequestWithTimeout(); the
//               resulting callback is shared by all socket operations here.
//
// Returns a WebSocketInitResult whose first argument is true on success
// (with the status, the headers and path), false otherwise with a status
// (0 when no status was parsed) and an error description.
//
WebSocketInitResult WebSocketHandshake::clientHandshake(const std::string& url,
                                                        const std::string& host,
                                                        const std::string& path,
                                                        int port,
                                                        int timeoutSecs)
{
    _requestInitCancellation = false;

    auto isCancellationRequested =
        makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);

    std::string errMsg;
    bool success = _socket->connect(host, port, errMsg, isCancellationRequested);
    if (!success)
    {
        std::stringstream err;
        err << "Unable to connect to " << host
            << " on port " << port
            << ", error: " << errMsg;
        return WebSocketInitResult(false, 0, err.str());
    }

    //
    // Generate a random 24 bytes string which looks like it is base64 encoded
    // y3JJHMbDL1EzLkh9GBhXDw==
    // 0cb3Vd9HkbpVVumoS3Noka==
    //
    // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
    //
    std::string secWebSocketKey = genRandomString(22);
    secWebSocketKey += "==";

    std::stringstream ss;
    ss << "GET " << path << " HTTP/1.1\r\n";
    ss << "Host: "<< host << ":" << port << "\r\n";
    ss << "Upgrade: websocket\r\n";
    ss << "Connection: Upgrade\r\n";
    ss << "Sec-WebSocket-Version: 13\r\n";
    ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";

    if (_enablePerMessageDeflate)
    {
        ss << _perMessageDeflateOptions.generateHeader();
    }

    // Empty line: end of the request headers
    ss << "\r\n";

    if (!_socket->writeBytes(ss.str(), isCancellationRequested))
    {
        return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + url);
    }

    // Read HTTP status line
    auto lineResult = _socket->readLine(isCancellationRequested);
    auto lineValid = lineResult.first;
    auto line = lineResult.second;

    if (!lineValid)
    {
        return WebSocketInitResult(false, 0,
                                   std::string("Failed reading HTTP status line from ") + url);
    }

    // Validate status.
    // Must be initialized: when neither sscanf below matches, status is
    // still read to build the error message and the returned result.
    int status = 0;

    // HTTP/1.0 is too old.
    if (sscanf(line.c_str(), "HTTP/1.0 %d", &status) == 1)
    {
        std::stringstream err;
        err << "Server version is HTTP/1.0. Rejecting connection to " << host
            << ", status: " << status
            << ", HTTP Status line: " << line;
        return WebSocketInitResult(false, status, err.str());
    }

    // We want an 101 HTTP status
    if (sscanf(line.c_str(), "HTTP/1.1 %d", &status) != 1 || status != 101)
    {
        std::stringstream err;
        err << "Got bad status connecting to " << host
            << ", status: " << status
            << ", HTTP Status line: " << line;
        return WebSocketInitResult(false, status, err.str());
    }

    auto result = parseHttpHeaders(_socket, isCancellationRequested);
    auto headersValid = result.first;
    auto headers = result.second;

    if (!headersValid)
    {
        return WebSocketInitResult(false, status, "Error parsing HTTP headers");
    }

    // Check the presence of the connection field
    if (headers.find("connection") == headers.end())
    {
        std::string errorMsg("Missing connection value");
        return WebSocketInitResult(false, status, errorMsg);
    }

    // Check the value of the connection field
    // Some websocket servers (Go/Gorilla?) send lowercase values for the
    // connection header, so do a case insensitive comparison
    if (!insensitiveStringCompare(headers["connection"], "Upgrade"))
    {
        std::stringstream err;
        err << "Invalid connection value: " << headers["connection"];
        return WebSocketInitResult(false, status, err.str());
    }

    // Compute the accept value expected for our key and compare it with the
    // one sent back by the server.
    char output[29] = {};
    WebSocketHandshakeKeyGen::generate(secWebSocketKey.c_str(), output);
    if (std::string(output) != headers["sec-websocket-accept"])
    {
        std::string errorMsg("Invalid Sec-WebSocket-Accept value");
        return WebSocketInitResult(false, status, errorMsg);
    }

    if (_enablePerMessageDeflate)
    {
        // Parse the server response. Does it support deflate ?
        std::string header = headers["sec-websocket-extensions"];
        WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

        // If the server does not support that extension, disable it.
        if (!webSocketPerMessageDeflateOptions.enabled())
        {
            _enablePerMessageDeflate = false;
        }
        // Otherwise try to initialize the deflate engine (zlib)
        else if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
        {
            return WebSocketInitResult(
                false, 0,"Failed to initialize per message deflate engine");
        }
    }

    return WebSocketInitResult(true, status, "", headers, path);
}
|
|
|
|
|
2019-01-04 03:33:08 +01:00
|
|
|
//
// Server side of the WebSocket opening handshake.
//
// Steps:
//   1. configure the accepted socket (fd)
//   2. read and validate the request line (GET, HTTP/1.1)
//   3. read the request headers and validate Sec-WebSocket-Key, Upgrade and
//      Sec-WebSocket-Version
//   4. answer with "101 Switching Protocols", including the computed
//      Sec-WebSocket-Accept and, if the client asked for it, the per message
//      deflate extension header
//
// fd          : descriptor passed to SocketConnect::configure(), also used
//               to describe the peer in error messages.
// timeoutSecs : timeout given to makeCancellationRequestWithTimeout(); the
//               resulting callback is shared by all socket operations here.
//
// Request validation failures are reported to the peer through
// sendErrorResponse(400, ...), whose result is returned. On success the
// result carries status 200, the request headers and the request URI.
//
WebSocketInitResult WebSocketHandshake::serverHandshake(int fd, int timeoutSecs)
{
    _requestInitCancellation = false;

    // Set the socket to non blocking mode + other tweaks
    SocketConnect::configure(fd);

    auto isCancellationRequested =
        makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);

    std::string remote = std::string("remote fd ") + std::to_string(fd);

    // Read first line
    auto lineResult = _socket->readLine(isCancellationRequested);
    auto lineValid = lineResult.first;
    auto line = lineResult.second;

    if (!lineValid)
    {
        return sendErrorResponse(400, "Error reading HTTP request line");
    }

    // Validate request line (GET /foo HTTP/1.1\r\n)
    auto requestLine = parseRequestLine(line);
    auto method      = std::get<0>(requestLine);
    auto uri         = std::get<1>(requestLine);
    auto httpVersion = std::get<2>(requestLine);

    if (method != "GET")
    {
        return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method);
    }

    if (httpVersion != "HTTP/1.1")
    {
        return sendErrorResponse(400, "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
    }

    // Retrieve and validate HTTP headers
    auto result = parseHttpHeaders(_socket, isCancellationRequested);
    auto headersValid = result.first;
    auto headers = result.second;

    if (!headersValid)
    {
        return sendErrorResponse(400, "Error parsing HTTP headers");
    }

    if (headers.find("sec-websocket-key") == headers.end())
    {
        return sendErrorResponse(400, "Missing Sec-WebSocket-Key value");
    }

    // The Upgrade token is matched case insensitively, so that clients
    // sending "WebSocket" are accepted too. This mirrors the case insensitive
    // check done on the Connection header in clientHandshake().
    if (!insensitiveStringCompare(headers["upgrade"], "websocket"))
    {
        return sendErrorResponse(400, "Invalid or missing Upgrade header");
    }

    if (headers.find("sec-websocket-version") == headers.end())
    {
        return sendErrorResponse(400, "Missing Sec-WebSocket-Version value");
    }

    {
        std::stringstream versionStream;
        versionStream << headers["sec-websocket-version"];

        // Stays at 0 (and is therefore rejected) if nothing can be parsed.
        int version = 0;
        versionStream >> version;

        if (version != 13)
        {
            return sendErrorResponse(400, "Invalid Sec-WebSocket-Version, "
                                          "need 13, got " + versionStream.str());
        }
    }

    // Compute the accept value matching the key sent by the client
    char output[29] = {};
    WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"].c_str(), output);

    std::stringstream ss;
    ss << "HTTP/1.1 101 Switching Protocols\r\n";
    ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n";
    ss << "Upgrade: websocket\r\n";
    ss << "Connection: Upgrade\r\n";

    // Parse the client headers. Does it support deflate ?
    std::string header = headers["sec-websocket-extensions"];
    WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);

    // If the client has requested that extension, enable it.
    if (webSocketPerMessageDeflateOptions.enabled())
    {
        _enablePerMessageDeflate = true;

        if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
        {
            return WebSocketInitResult(
                false, 0,"Failed to initialize per message deflate engine");
        }
        ss << webSocketPerMessageDeflateOptions.generateHeader();
    }

    // Empty line: end of the response headers
    ss << "\r\n";

    if (!_socket->writeBytes(ss.str(), isCancellationRequested))
    {
        return WebSocketInitResult(false, 0, std::string("Failed sending response to ") + remote);
    }

    return WebSocketInitResult(true, 200, "", headers, uri);
}
|
|
|
|
}
|