WebSocketPerMessageDeflateCompressor can work with vector or std::string

This commit is contained in:
Benjamin Sergeant 2020-07-07 18:17:44 -07:00
parent e9e768a288
commit 95eab59c08
3 changed files with 65 additions and 9 deletions

View File

@ -59,14 +59,38 @@ namespace ix
return true; return true;
} }
bool WebSocketPerMessageDeflateCompressor::endsWith(const std::string& value, template<typename T>
const std::string& ending) bool WebSocketPerMessageDeflateCompressor::endsWithEmptyUnCompressedBlock(const T& value)
{ {
if (ending.size() > value.size()) return false; if (kEmptyUncompressedBlock.size() > value.size()) return false;
return std::equal(ending.rbegin(), ending.rend(), value.rbegin()); auto N = value.size();
return value[N - 1] == kEmptyUncompressedBlock[3] &&
value[N - 2] == kEmptyUncompressedBlock[2] &&
value[N - 3] == kEmptyUncompressedBlock[1] &&
value[N - 4] == kEmptyUncompressedBlock[0];
} }
bool WebSocketPerMessageDeflateCompressor::compress(const std::string& in, std::string& out) bool WebSocketPerMessageDeflateCompressor::compress(const std::string& in, std::string& out)
{
return compressData(in, out);
}
bool WebSocketPerMessageDeflateCompressor::compress(const std::string& in, std::vector<uint8_t>& out)
{
return compressData(in, out);
}
bool WebSocketPerMessageDeflateCompressor::compress(const std::vector<uint8_t>& in, std::string& out)
{
return compressData(in, out);
}
bool WebSocketPerMessageDeflateCompressor::compress(const std::vector<uint8_t>& in, std::vector<uint8_t>& out)
{
return compressData(in, out);
}
template<typename T, typename S> bool WebSocketPerMessageDeflateCompressor::compressData(const T& in, S& out)
{ {
// //
// 7.2.1. Compression // 7.2.1. Compression
@ -96,7 +120,8 @@ namespace ix
// The normal buffer size should be 6 but // The normal buffer size should be 6 but
// we remove the 4 octets from the tail (#4) // we remove the 4 octets from the tail (#4)
uint8_t buf[2] = {0x02, 0x00}; uint8_t buf[2] = {0x02, 0x00};
out.append((char*) (buf), 2); out.push_back(buf[0]);
out.push_back(buf[1]);
return true; return true;
} }
@ -114,10 +139,10 @@ namespace ix
output = _compressBufferSize - _deflateState.avail_out; output = _compressBufferSize - _deflateState.avail_out;
out.append((char*) (_compressBuffer.get()), output); out.insert(out.end(), _compressBuffer.get(), _compressBuffer.get() + output);
} while (_deflateState.avail_out == 0); } while (_deflateState.avail_out == 0);
if (endsWith(out, kEmptyUncompressedBlock)) if (endsWithEmptyUnCompressedBlock(out))
{ {
out.resize(out.size() - 4); out.resize(out.size() - 4);
} }

View File

@ -9,6 +9,7 @@
#include "zlib.h" #include "zlib.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
namespace ix namespace ix
{ {
@ -20,9 +21,13 @@ namespace ix
bool init(uint8_t deflateBits, bool clientNoContextTakeOver); bool init(uint8_t deflateBits, bool clientNoContextTakeOver);
bool compress(const std::string& in, std::string& out); bool compress(const std::string& in, std::string& out);
bool compress(const std::string& in, std::vector<uint8_t>& out);
bool compress(const std::vector<uint8_t>& in, std::string& out);
bool compress(const std::vector<uint8_t>& in, std::vector<uint8_t>& out);
private: private:
static bool endsWith(const std::string& value, const std::string& ending); template<typename T, typename S> bool compressData(const T& in, S& out);
template<typename T> bool endsWithEmptyUnCompressedBlock(const T& value);
int _flush; int _flush;
size_t _compressBufferSize; size_t _compressBufferSize;

View File

@ -31,9 +31,26 @@ namespace ix
return c; return c;
} }
std::string compressAndDecompressVector(const std::string& a)
{
std::string b, c;
std::vector<uint8_t> vec(a.begin(), a.end());
WebSocketPerMessageDeflateCompressor compressor;
compressor.init(11, true);
compressor.compress(vec, b);
WebSocketPerMessageDeflateDecompressor decompressor;
decompressor.init(11, true);
decompressor.decompress(b, c);
return c;
}
TEST_CASE("per-message-deflate-codec", "[zlib]") TEST_CASE("per-message-deflate-codec", "[zlib]")
{ {
SECTION("blah") SECTION("string api")
{ {
REQUIRE(compressAndDecompress("") == ""); REQUIRE(compressAndDecompress("") == "");
REQUIRE(compressAndDecompress("foo") == "foo"); REQUIRE(compressAndDecompress("foo") == "foo");
@ -41,6 +58,15 @@ namespace ix
REQUIRE(compressAndDecompress("asdcaseqw`21897dehqwed") == "asdcaseqw`21897dehqwed"); REQUIRE(compressAndDecompress("asdcaseqw`21897dehqwed") == "asdcaseqw`21897dehqwed");
REQUIRE(compressAndDecompress("/usr/local/include/ixwebsocket/IXSocketAppleSSL.h") == "/usr/local/include/ixwebsocket/IXSocketAppleSSL.h"); REQUIRE(compressAndDecompress("/usr/local/include/ixwebsocket/IXSocketAppleSSL.h") == "/usr/local/include/ixwebsocket/IXSocketAppleSSL.h");
} }
SECTION("vector api")
{
REQUIRE(compressAndDecompressVector("") == "");
REQUIRE(compressAndDecompressVector("foo") == "foo");
REQUIRE(compressAndDecompressVector("bar") == "bar");
REQUIRE(compressAndDecompressVector("asdcaseqw`21897dehqwed") == "asdcaseqw`21897dehqwed");
REQUIRE(compressAndDecompressVector("/usr/local/include/ixwebsocket/IXSocketAppleSSL.h") == "/usr/local/include/ixwebsocket/IXSocketAppleSSL.h");
}
} }
} // namespace ix } // namespace ix