Handle Sec-WebSocket-Accept correctly
This commit is contained in:
		@@ -19,6 +19,8 @@
 | 
				
			|||||||
# endif
 | 
					# endif
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "libwshandshake.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// #include <unistd.h>
 | 
					// #include <unistd.h>
 | 
				
			||||||
#include <string.h>
 | 
					#include <string.h>
 | 
				
			||||||
#include <stdlib.h>
 | 
					#include <stdlib.h>
 | 
				
			||||||
@@ -30,6 +32,8 @@
 | 
				
			|||||||
#include <iostream>
 | 
					#include <iostream>
 | 
				
			||||||
#include <sstream>
 | 
					#include <sstream>
 | 
				
			||||||
#include <regex>
 | 
					#include <regex>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <random>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace ix {
 | 
					namespace ix {
 | 
				
			||||||
@@ -128,6 +132,28 @@ namespace ix {
 | 
				
			|||||||
        std::cout << "-------------------------------" << std::endl;
 | 
					        std::cout << "-------------------------------" << std::endl;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::string WebSocketTransport::genRandomString(const int len)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        static const char alphanum[] =
 | 
				
			||||||
 | 
					            "0123456789"
 | 
				
			||||||
 | 
					            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
 | 
				
			||||||
 | 
					            "abcdefghijklmnopqrstuvwxyz";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::random_device r;
 | 
				
			||||||
 | 
					        std::default_random_engine e1(r());
 | 
				
			||||||
 | 
					        std::uniform_int_distribution<int> dist(0, sizeof(alphanum) - 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::string s;
 | 
				
			||||||
 | 
					        s.resize(len);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (int i = 0; i < len; ++i)
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            s[i] += alphanum[dist(e1)];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return s;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    WebSocketInitResult WebSocketTransport::init()
 | 
					    WebSocketInitResult WebSocketTransport::init()
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        std::string protocol, host, path, query;
 | 
					        std::string protocol, host, path, query;
 | 
				
			||||||
@@ -169,6 +195,16 @@ namespace ix {
 | 
				
			|||||||
            return WebSocketInitResult(false, 0, ss.str());
 | 
					            return WebSocketInitResult(false, 0, ss.str());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //
 | 
				
			||||||
 | 
					        // Generate a random 24 bytes string which looks like it is base64 encoded
 | 
				
			||||||
 | 
					        // y3JJHMbDL1EzLkh9GBhXDw==
 | 
				
			||||||
 | 
					        // 0cb3Vd9HkbpVVumoS3Noka==
 | 
				
			||||||
 | 
					        //
 | 
				
			||||||
 | 
					        // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
 | 
				
			||||||
 | 
					        //
 | 
				
			||||||
 | 
					        std::string secWebSocketKey = genRandomString(22);
 | 
				
			||||||
 | 
					        secWebSocketKey += "==";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        char line[256];
 | 
					        char line[256];
 | 
				
			||||||
        int status;
 | 
					        int status;
 | 
				
			||||||
        int i;
 | 
					        int i;
 | 
				
			||||||
@@ -177,12 +213,10 @@ namespace ix {
 | 
				
			|||||||
            "Host: %s:%d\r\n"
 | 
					            "Host: %s:%d\r\n"
 | 
				
			||||||
            "Upgrade: websocket\r\n"
 | 
					            "Upgrade: websocket\r\n"
 | 
				
			||||||
            "Connection: Upgrade\r\n"
 | 
					            "Connection: Upgrade\r\n"
 | 
				
			||||||
            "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
 | 
					            "Sec-WebSocket-Key: %s\r\n"
 | 
				
			||||||
            "Sec-WebSocket-Version: 13\r\n"
 | 
					            "Sec-WebSocket-Version: 13\r\n"
 | 
				
			||||||
            "\r\n",
 | 
					            "\r\n",
 | 
				
			||||||
            path.c_str(), host.c_str(), port);
 | 
					            path.c_str(), host.c_str(), port, secWebSocketKey.c_str());
 | 
				
			||||||
 | 
					 | 
				
			||||||
        // XXX: this should be done non-blocking,
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        size_t lineSize = strlen(line);
 | 
					        size_t lineSize = strlen(line);
 | 
				
			||||||
        if (_socket->send(line, lineSize) != lineSize)
 | 
					        if (_socket->send(line, lineSize) != lineSize)
 | 
				
			||||||
@@ -224,9 +258,12 @@ namespace ix {
 | 
				
			|||||||
            return WebSocketInitResult(false, status, ss.str());
 | 
					            return WebSocketInitResult(false, status, ss.str());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // TODO: verify response headers,
 | 
					        std::unordered_map<std::string, std::string> headers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        while (true) 
 | 
					        while (true) 
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
 | 
					            int colon = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (i = 0;
 | 
					            for (i = 0;
 | 
				
			||||||
                 i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n');
 | 
					                 i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n');
 | 
				
			||||||
                 ++i)
 | 
					                 ++i)
 | 
				
			||||||
@@ -235,11 +272,38 @@ namespace ix {
 | 
				
			|||||||
                {
 | 
					                {
 | 
				
			||||||
                    return WebSocketInitResult(false, status, std::string("Failed reading response header from ") + _url);
 | 
					                    return WebSocketInitResult(false, status, std::string("Failed reading response header from ") + _url);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (line[i] == ':' && colon == 0)
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    colon = i;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            if (line[0] == '\r' && line[1] == '\n')
 | 
					            if (line[0] == '\r' && line[1] == '\n')
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                break;
 | 
					                break;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // line is a single header entry. split by ':', and add it to our
 | 
				
			||||||
 | 
					            // header map. ignore lines with no colon.
 | 
				
			||||||
 | 
					            if (colon > 0)
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                line[i] = '\0';
 | 
				
			||||||
 | 
					                std::string lineStr(line);
 | 
				
			||||||
 | 
					                // colon is ':', colon+1 is ' ', colon+2 is the start of the value.
 | 
				
			||||||
 | 
					                // i is end of string (\0), i-colon is length of string minus key;
 | 
				
			||||||
 | 
					                // subtract 1 for '\0', 1 for '\n', 1 for '\r',
 | 
				
			||||||
 | 
					                // 1 for the ' ' after the ':', and total is -4
 | 
				
			||||||
 | 
					                headers[lineStr.substr(0, colon)] =
 | 
				
			||||||
 | 
					                    lineStr.substr(colon + 2, i - colon - 4);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        char output[29] = {};
 | 
				
			||||||
 | 
					        WebSocketHandshake::generate(secWebSocketKey.c_str(), output);
 | 
				
			||||||
 | 
					        if (std::string(output) != headers["Sec-WebSocket-Accept"])
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            std::string errorMsg("Invalid Sec-WebSocket-Accept value");
 | 
				
			||||||
 | 
					            return WebSocketInitResult(false, status, errorMsg);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        _socket->configure();
 | 
					        _socket->configure();
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -142,5 +142,6 @@ namespace ix
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        unsigned getRandomUnsigned();
 | 
					        unsigned getRandomUnsigned();
 | 
				
			||||||
        void unmaskReceiveBuffer(const wsheader_type& ws);
 | 
					        void unmaskReceiveBuffer(const wsheader_type& ws);
 | 
				
			||||||
 | 
					        std::string genRandomString(const int len);
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										128
									
								
								ixwebsocket/libwshandshake.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								ixwebsocket/libwshandshake.hpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,128 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2016 Alex Hultman and contributors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// This software is provided 'as-is', without any express or implied
 | 
				
			||||||
 | 
					// warranty. In no event will the authors be held liable for any damages
 | 
				
			||||||
 | 
					// arising from the use of this software.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Permission is granted to anyone to use this software for any purpose,
 | 
				
			||||||
 | 
					// including commercial applications, and to alter it and redistribute it
 | 
				
			||||||
 | 
					// freely, subject to the following restrictions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 1. The origin of this software must not be misrepresented; you must not
 | 
				
			||||||
 | 
					//    claim that you wrote the original software. If you use this software
 | 
				
			||||||
 | 
					//    in a product, an acknowledgement in the product documentation would be
 | 
				
			||||||
 | 
					//    appreciated but is not required.
 | 
				
			||||||
 | 
					// 2. Altered source versions must be plainly marked as such, and must not be
 | 
				
			||||||
 | 
					//    misrepresented as being the original software.
 | 
				
			||||||
 | 
					// 3. This notice may not be removed or altered from any source distribution.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstdint>
 | 
				
			||||||
 | 
					#include <cstddef>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketHandshake {
 | 
				
			||||||
 | 
					    template <int N, typename T>
 | 
				
			||||||
 | 
					    struct static_for {
 | 
				
			||||||
 | 
					        void operator()(uint32_t *a, uint32_t *b) {
 | 
				
			||||||
 | 
					            static_for<N - 1, T>()(a, b);
 | 
				
			||||||
 | 
					            T::template f<N - 1>(a, b);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template <typename T>
 | 
				
			||||||
 | 
					    struct static_for<0, T> {
 | 
				
			||||||
 | 
					        void operator()(uint32_t *a, uint32_t *hash) {}
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template <int state>
 | 
				
			||||||
 | 
					    struct Sha1Loop {
 | 
				
			||||||
 | 
					        static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
 | 
				
			||||||
 | 
					        static inline uint32_t blk(uint32_t b[16], size_t i) {
 | 
				
			||||||
 | 
					            return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        template <int i>
 | 
				
			||||||
 | 
					        static inline void f(uint32_t *a, uint32_t *b) {
 | 
				
			||||||
 | 
					            switch (state) {
 | 
				
			||||||
 | 
					            case 1:
 | 
				
			||||||
 | 
					                a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
 | 
				
			||||||
 | 
					                a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            case 2:
 | 
				
			||||||
 | 
					                b[i] = blk(b, i);
 | 
				
			||||||
 | 
					                a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
 | 
				
			||||||
 | 
					                a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            case 3:
 | 
				
			||||||
 | 
					                b[(i + 4) % 16] = blk(b, (i + 4) % 16);
 | 
				
			||||||
 | 
					                a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
 | 
				
			||||||
 | 
					                a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            case 4:
 | 
				
			||||||
 | 
					                b[(i + 8) % 16] = blk(b, (i + 8) % 16);
 | 
				
			||||||
 | 
					                a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
 | 
				
			||||||
 | 
					                a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            case 5:
 | 
				
			||||||
 | 
					                b[(i + 12) % 16] = blk(b, (i + 12) % 16);
 | 
				
			||||||
 | 
					                a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
 | 
				
			||||||
 | 
					                a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            case 6:
 | 
				
			||||||
 | 
					                b[i] += a[4 - i];
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    static inline void sha1(uint32_t hash[5], uint32_t b[16]) {
 | 
				
			||||||
 | 
					        uint32_t a[5] = {hash[4], hash[3], hash[2], hash[1], hash[0]};
 | 
				
			||||||
 | 
					        static_for<16, Sha1Loop<1>>()(a, b);
 | 
				
			||||||
 | 
					        static_for<4, Sha1Loop<2>>()(a, b);
 | 
				
			||||||
 | 
					        static_for<20, Sha1Loop<3>>()(a, b);
 | 
				
			||||||
 | 
					        static_for<20, Sha1Loop<4>>()(a, b);
 | 
				
			||||||
 | 
					        static_for<20, Sha1Loop<5>>()(a, b);
 | 
				
			||||||
 | 
					        static_for<5, Sha1Loop<6>>()(a, hash);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    static inline void base64(unsigned char *src, char *dst) {
 | 
				
			||||||
 | 
					        const char *b64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
 | 
				
			||||||
 | 
					        for (int i = 0; i < 18; i += 3) {
 | 
				
			||||||
 | 
					            *dst++ = b64[(src[i] >> 2) & 63];
 | 
				
			||||||
 | 
					            *dst++ = b64[((src[i] & 3) << 4) | ((src[i + 1] & 240) >> 4)];
 | 
				
			||||||
 | 
					            *dst++ = b64[((src[i + 1] & 15) << 2) | ((src[i + 2] & 192) >> 6)];
 | 
				
			||||||
 | 
					            *dst++ = b64[src[i + 2] & 63];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        *dst++ = b64[(src[18] >> 2) & 63];
 | 
				
			||||||
 | 
					        *dst++ = b64[((src[18] & 3) << 4) | ((src[19] & 240) >> 4)];
 | 
				
			||||||
 | 
					        *dst++ = b64[((src[19] & 15) << 2)];
 | 
				
			||||||
 | 
					        *dst++ = '=';
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					    static inline void generate(const char input[24], char output[28]) {
 | 
				
			||||||
 | 
					        uint32_t b_output[5] = {
 | 
				
			||||||
 | 
					            0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					        uint32_t b_input[16] = {
 | 
				
			||||||
 | 
					            0, 0, 0, 0, 0, 0, 0x32353845, 0x41464135, 0x2d453931, 0x342d3437, 0x44412d39,
 | 
				
			||||||
 | 
					            0x3543412d, 0x43354142, 0x30444338, 0x35423131, 0x80000000
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (int i = 0; i < 6; i++) {
 | 
				
			||||||
 | 
					            b_input[i] = (input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        sha1(b_output, b_input);
 | 
				
			||||||
 | 
					        uint32_t last_b[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 480};
 | 
				
			||||||
 | 
					        sha1(b_output, last_b);
 | 
				
			||||||
 | 
					        for (int i = 0; i < 5; i++) {
 | 
				
			||||||
 | 
					            uint32_t tmp = b_output[i];
 | 
				
			||||||
 | 
					            char *bytes = (char *) &b_output[i];
 | 
				
			||||||
 | 
					            bytes[3] = tmp & 0xff;
 | 
				
			||||||
 | 
					            bytes[2] = (tmp >> 8) & 0xff;
 | 
				
			||||||
 | 
					            bytes[1] = (tmp >> 16) & 0xff;
 | 
				
			||||||
 | 
					            bytes[0] = (tmp >> 24) & 0xff;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        base64((unsigned char *) b_output, output);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
		Reference in New Issue
	
	Block a user