per message deflate support (with zlib)

This commit is contained in:
Benjamin Sergeant
2018-11-09 18:23:49 -08:00
parent 32f4c8305e
commit 43fcf93584
32 changed files with 1003 additions and 257 deletions

View File

@ -9,6 +9,7 @@
//
#include "IXWebSocketTransport.h"
#include "IXWebSocketHttpHeaders.h"
#include "IXSocket.h"
#ifdef IXWEBSOCKET_USE_TLS
@ -31,18 +32,17 @@
#include <iostream>
#include <sstream>
#include <regex>
#include <unordered_map>
#include <random>
#include <algorithm>
namespace ix {
namespace ix
{
WebSocketTransport::WebSocketTransport() :
_readyState(CLOSED),
_enablePerMessageDeflate(false)
{
_perMessageDeflate.init();
}
WebSocketTransport::~WebSocketTransport()
@ -50,9 +50,12 @@ namespace ix {
;
}
void WebSocketTransport::configure(const std::string& url)
void WebSocketTransport::configure(const std::string& url,
const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions)
{
_url = url;
_perMessageDeflateOptions = perMessageDeflateOptions;
_enablePerMessageDeflate = _perMessageDeflateOptions.enabled();
}
bool WebSocketTransport::parseUrl(const std::string& url,
@ -135,21 +138,22 @@ namespace ix {
std::string WebSocketTransport::genRandomString(const int len)
{
static const char alphanum[] =
std::string alphanum =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
"ABCDEFGH"
"abcdefgh";
std::random_device r;
std::default_random_engine e1(r());
std::uniform_int_distribution<int> dist(0, sizeof(alphanum) - 1);
std::uniform_int_distribution<int> dist(0, (int) alphanum.size() - 1);
std::string s;
s.resize(len);
for (int i = 0; i < len; ++i)
{
s[i] += alphanum[dist(e1)];
int x = dist(e1);
s[i] = alphanum[x];
}
return s;
@ -206,35 +210,31 @@ namespace ix {
std::string secWebSocketKey = genRandomString(22);
secWebSocketKey += "==";
std::string extensions;
std::stringstream ss;
ss << "GET " << path << " HTTP/1.1\r\n";
ss << "Host: "<< host << ":" << port << "\r\n";
ss << "Upgrade: websocket\r\n";
ss << "Connection: Upgrade\r\n";
ss << "Sec-WebSocket-Version: 13\r\n";
ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";
if (_enablePerMessageDeflate)
{
// extensions = "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n";
extensions = "Sec-WebSocket-Extensions: permessage-deflate\r\n";
ss << _perMessageDeflateOptions.generateHeader();
}
char line[512];
int status;
int i;
snprintf(line, 512,
"GET %s HTTP/1.1\r\n"
"Host: %s:%d\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: %s\r\n"
"Sec-WebSocket-Version: 13\r\n"
"%s"
"\r\n",
path.c_str(), host.c_str(), port,
secWebSocketKey.c_str(), extensions.c_str());
ss << "\r\n";
size_t lineSize = strlen(line);
if (_socket->send(line, lineSize) != lineSize)
std::string request = ss.str();
int requestSize = (int) request.size();
if (_socket->send(const_cast<char*>(request.c_str()), requestSize) != requestSize)
{
return WebSocketInitResult(false, 0, std::string("Failed sending GET request to ") + _url);
}
char line[512];
int i;
for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i)
{
if (_socket->recv(line+i, 1) == 0)
@ -248,6 +248,9 @@ namespace ix {
return WebSocketInitResult(false, 0, std::string("Got bad status line connecting to ") + _url);
}
// Validate status
int status;
// HTTP/1.0 is too old.
if (sscanf(line, "HTTP/1.0 %d", &status) == 1)
{
@ -268,7 +271,7 @@ namespace ix {
return WebSocketInitResult(false, status, ss.str());
}
std::unordered_map<std::string, std::string> headers;
WebSocketHttpHeaders headers;
while (true)
{
@ -310,7 +313,6 @@ namespace ix {
std::transform(name.begin(), name.end(), name.begin(), ::tolower);
headers[name] = value;
std::cout << name << " -> " << value << std::endl;
}
}
@ -322,10 +324,29 @@ namespace ix {
return WebSocketInitResult(false, status, errorMsg);
}
if (_enablePerMessageDeflate)
{
// Parse the server response. Does it support deflate ?
std::string header = headers["sec-websocket-extensions"];
WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
// If the server does not support that extension, disable it.
if (!webSocketPerMessageDeflateOptions.enabled())
{
_enablePerMessageDeflate = false;
}
if (!_perMessageDeflate.init(webSocketPerMessageDeflateOptions))
{
return WebSocketInitResult(
false, 0,"Failed to initialize per message deflate engine");
}
}
_socket->configure();
setReadyState(OPEN);
return WebSocketInitResult(true, status, "");
return WebSocketInitResult(true, status, "", headers);
}
WebSocketTransport::ReadyStateValues WebSocketTransport::getReadyState() const
@ -341,7 +362,7 @@ namespace ix {
if (readyStateValue == CLOSED)
{
std::lock_guard<std::mutex> lock(_closeDataMutex);
_onCloseCallback(_closeCode, _closeReason);
_onCloseCallback(_closeCode, _closeReason, _closeWireSize);
_closeCode = 0;
_closeReason = std::string();
}
@ -546,33 +567,7 @@ namespace ix {
std::string stringMessage(_receivedData.begin(),
_receivedData.end());
std::cout << "raw msg: " << stringMessage << std::endl;
std::cout << "raw msg size: " << stringMessage.size() << std::endl;
// ws.rsv1 means the message is compressed
// FIXME hack hack
std::string decompressedMessage;
if (_enablePerMessageDeflate && ws.rsv1)
{
if (_perMessageDeflate.decompress(stringMessage,
decompressedMessage))
{
std::cout << "decompressed msg: " << decompressedMessage << std::endl;
std::cout << "msg size: " << decompressedMessage.size() << std::endl;
onMessageCallback(decompressedMessage, MSG);
}
else
{
std::cout << "error decompressing msg !"<< std::endl;
}
}
else
{
onMessageCallback(stringMessage, MSG);
}
emitMessage(MSG, stringMessage, ws, onMessageCallback);
_receivedData.clear();
}
}
@ -583,10 +578,9 @@ namespace ix {
_rxbuf.begin()+ws.header_size + (size_t) ws.N);
// Reply back right away
sendData(wsheader_type::PONG, pingData.size(),
pingData.begin(), pingData.end());
sendData(wsheader_type::PONG, pingData);
onMessageCallback(pingData, PING);
emitMessage(PING, pingData, ws, onMessageCallback);
}
else if (ws.opcode == wsheader_type::PONG)
{
@ -594,7 +588,7 @@ namespace ix {
std::string pongData(_rxbuf.begin()+ws.header_size,
_rxbuf.begin()+ws.header_size + (size_t) ws.N);
onMessageCallback(pongData, PONG);
emitMessage(PONG, pongData, ws, onMessageCallback);
}
else if (ws.opcode == wsheader_type::CLOSE)
{
@ -613,6 +607,7 @@ namespace ix {
std::lock_guard<std::mutex> lock(_closeDataMutex);
_closeCode = code;
_closeReason = reason;
_closeWireSize = _rxbuf.size();
}
close();
@ -627,6 +622,32 @@ namespace ix {
}
}
void WebSocketTransport::emitMessage(MessageKind messageKind,
const std::string& message,
const wsheader_type& ws,
const OnMessageCallback& onMessageCallback)
{
// ws.rsv1 means the message is compressed
std::string decompressedMessage;
if (_enablePerMessageDeflate && ws.rsv1)
{
if (_perMessageDeflate.decompress(message, decompressedMessage))
{
onMessageCallback(decompressedMessage, decompressedMessage.size(),
messageKind);
}
else
{
std::cerr << "error decompressing msg !"<< std::endl;
}
}
else
{
onMessageCallback(message, message.size(), messageKind);
}
}
unsigned WebSocketTransport::getRandomUnsigned()
{
auto now = std::chrono::system_clock::now();
@ -636,16 +657,32 @@ namespace ix {
return static_cast<unsigned>(seconds);
}
void WebSocketTransport::sendData(wsheader_type::opcode_type type,
uint64_t message_size,
std::string::const_iterator message_begin,
std::string::const_iterator message_end)
WebSocketSendInfo WebSocketTransport::sendData(wsheader_type::opcode_type type,
const std::string& message)
{
if (_readyState == CLOSING || _readyState == CLOSED)
{
return;
return WebSocketSendInfo();
}
size_t payloadSize = message.size();
size_t wireSize = message.size();
std::string compressedMessage;
std::string::const_iterator message_begin = message.begin();
std::string::const_iterator message_end = message.end();
if (_enablePerMessageDeflate)
{
_perMessageDeflate.compress(message, compressedMessage);
wireSize = compressedMessage.size();
message_begin = compressedMessage.begin();
message_end = compressedMessage.end();
}
uint64_t message_size = wireSize;
unsigned x = getRandomUnsigned();
uint8_t masking_key[4] = {};
masking_key[0] = (x >> 24);
@ -709,36 +746,18 @@ namespace ix {
// Now actually send this data
sendOnSocket();
return WebSocketSendInfo(true, payloadSize, wireSize);
}
void WebSocketTransport::sendPing(const std::string& message)
WebSocketSendInfo WebSocketTransport::sendPing(const std::string& message)
{
sendData(wsheader_type::PING, message.size(), message.begin(), message.end());
return sendData(wsheader_type::PING, message);
}
void WebSocketTransport::sendBinary(const std::string& message)
WebSocketSendInfo WebSocketTransport::sendBinary(const std::string& message)
{
if (_enablePerMessageDeflate)
{
// FIXME hack hack
std::string compressedMessage;
_perMessageDeflate.compress(message, compressedMessage);
std::cout << "uncompressedMessage " << message << std::endl;
std::cout << "uncompressedMessage.size() " << message.size() << std::endl;
std::cout << "compressedMessage.size() " << compressedMessage.size()
<< std::endl;
// sendData(wsheader_type::BINARY_FRAME, message.size(), message.begin(), message.end());
sendData(wsheader_type::BINARY_FRAME,
compressedMessage.size(),
compressedMessage.begin(),
compressedMessage.end());
}
else
{
sendData(wsheader_type::BINARY_FRAME, message.size(),
message.begin(), message.end());
}
return sendData(wsheader_type::BINARY_FRAME, message);
}
void WebSocketTransport::sendOnSocket()