/*
 * Authored by Alex Hultman, 2018-2019.
 * Intellectual property of third-party.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef UWS_WEBSOCKETPROTOCOL_H
#define UWS_WEBSOCKETPROTOCOL_H

#include <cstdint>
#include <cstring>
#include <cstdlib>

/* This header uses WIN32_EXPORT but does not define it. Fall back to an empty
 * definition so the header still compiles when nothing earlier provided one. */
#ifndef WIN32_EXPORT
#define WIN32_EXPORT
#endif

namespace uWS {

/* WebSocket frame opcodes that this parser accepts (continuation frames use opcode 0). */
enum OpCode : unsigned char {
    TEXT = 1,
    BINARY = 2,
    CLOSE = 8,
    PING = 9,
    PONG = 10
};

enum {
    CLIENT,
    SERVER
};

/* Per-connection parser state, carried between calls to WebSocketProtocol::consume.
 * The isServer variant parses masked frames, hence the 4 extra header bytes. */
// 24 bytes perfectly
template <bool isServer>
struct WebSocketState {
public:
    static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
    static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
    static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;

    // 16 bytes
    struct State {
        /* 1 while the parser expects a frame header, 0 while it is inside a payload */
        unsigned int wantsHead : 1;
        /* Number of bytes of an incomplete header currently saved in spill */
        unsigned int spillLength : 4;
        /* Index of the active entry in opCode; explicitly signed because it must hold -1 */
        signed int opStack : 2; // -1, 0, 1
        /* FIN bit of the most recently parsed frame */
        unsigned int lastFin : 1;

        // 15 bytes
        unsigned char spill[LONG_MESSAGE_HEADER - 1];
        OpCode opCode[2];

        State() {
            wantsHead = true;
            spillLength = 0;
            opStack = -1;
            lastFin = true;
        }
    } state;

    // 8 bytes
    /* Payload bytes of the current frame that have not been received yet */
    unsigned int remainingBytes = 0;
    /* Masking key of the frame being continued (only meaningful when isServer) */
    char mask[isServer ? 4 : 1];
};

namespace protocol {

/* Reads a T from a possibly unaligned byte pointer. */
template <typename T>
T bit_cast(char *c) {
    T val;
    memcpy(&val, c, sizeof(T));
    return val;
}

/* Byte swap for little-endian systems: reverses the bytes of value when the
 * host is little-endian and returns it untouched otherwise. The bytes are
 * shuffled through memcpy rather than a union so no inactive union member is
 * ever read. */
template <typename T>
T cond_byte_swap(T value) {
    uint32_t endian_test = 1;
    if (*((char *) &endian_test)) {
        unsigned char bytes[sizeof(T)];
        memcpy(bytes, &value, sizeof(T));
        for (unsigned int i = 0; i < sizeof(T) / 2; i++) {
            unsigned char tmp = bytes[i];
            bytes[i] = bytes[sizeof(T) - 1 - i];
            bytes[sizeof(T) - 1 - i] = tmp;
        }
        memcpy(&value, bytes, sizeof(T));
    }
    return value;
}

// Based on utf8_check.c by Markus Kuhn, 2005
// https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c
// Optimized for predominantly 7-bit content by Alex Hultman, 2016
// Licensed as Zlib, like the rest of this project
/* Returns true when the length bytes at s form well-formed UTF-8. Truncated
 * sequences, bad continuation bytes, overlong encodings, the surrogate range
 * (lead 0xed followed by 0xa0..0xbf) and anything above 0xf4 0x8f are rejected.
 * An empty range is valid. */
static bool isValidUtf8(unsigned char *s, size_t length) {
    for (unsigned char *e = s + length; s != e; ) {
        /* Fast path: skip four bytes at once when none of them has the high bit set */
        if (s + 4 <= e) {
            uint32_t tmp;
            memcpy(&tmp, s, 4);
            if ((tmp & 0x80808080) == 0) {
                s += 4;
                continue;
            }
        }

        /* Skip the remaining 7-bit bytes one by one */
        while (!(*s & 0x80)) {
            if (++s == e) {
                return true;
            }
        }

        /* The high bit of s[0] is set from here on, so testing bits 6..5 is enough
         * to recognise the 110xxxxx lead byte of a two-byte sequence */
        if ((s[0] & 0x60) == 0x40) {
            if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
                return false;
            }
            s += 2;
        } else if ((s[0] & 0xf0) == 0xe0) {
            if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
                    (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
                return false;
            }
            s += 3;
        } else if ((s[0] & 0xf8) == 0xf0) {
            if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 ||
                    (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
                return false;
            }
            s += 4;
        } else {
            return false;
        }
    }
    return true;
}

/* Decoded close payload. message points into the buffer that was parsed; it is not a copy. */
struct CloseFrame {
    uint16_t code;
    char *message;
    size_t length;
};

/* Splits a close-frame payload into status code and reason text.
 * Returns an all-zero CloseFrame when the payload is shorter than 2 bytes, when
 * the code is outside the ranges accepted below (1000..1003, 1007..1011,
 * 4000..4999) or when the reason is not valid UTF-8. */
static inline CloseFrame parseClosePayload(char *src, size_t length) {
    CloseFrame cf = {};
    if (length >= 2) {
        memcpy(&cf.code, src, 2);
        cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
        if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
                (cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
            return {};
        }
    }
    return cf;
}

/* Writes a close payload (2-byte big-endian code followed by the reason) to dst
 * and returns the number of bytes produced. A code of 0 writes nothing and
 * returns 0. dst must have room for length + 2 bytes. */
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
    if (code) {
        code = cond_byte_swap<uint16_t>(code);
        memcpy(dst, &code, 2);
        /* It is invalid to pass nullptr to memcpy, even though length is 0 */
        if (message) {
            memcpy(dst + 2, message, length);
        }
        return length + 2;
    }
    return 0;
}

/* Size of an unmasked frame (header plus payload) carrying messageSize payload bytes. */
static inline size_t messageFrameSize(size_t messageSize) {
    if (messageSize < 126) {
        return 2 + messageSize;
    } else if (messageSize <= UINT16_MAX) {
        return 4 + messageSize;
    }
    return 10 + messageSize;
}

enum {
    SND_CONTINUATION = 1,
    SND_NO_FIN = 2,
    SND_COMPRESSED = 64
};

/* Serialises one complete (FIN set) frame into dst and returns the number of
 * bytes written.
 *
 * reportedLength is the payload length written into the header, while length is
 * the number of bytes actually copied from src. When isServer is false the
 * frame is masked: a 4-byte key is appended to the header and the copied
 * payload is XORed with it. dst must be large enough for the header, the
 * optional key and length payload bytes. */
template <bool isServer>
static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed) {
    size_t headerLength;
    if (reportedLength < 126) {
        headerLength = 2;
        dst[1] = (char) reportedLength;
    } else if (reportedLength <= UINT16_MAX) {
        headerLength = 4;
        dst[1] = 126;
        uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t) reportedLength);
        memcpy(&dst[2], &tmp, sizeof(uint16_t));
    } else {
        headerLength = 10;
        dst[1] = 127;
        uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t) reportedLength);
        memcpy(&dst[2], &tmp, sizeof(uint64_t));
    }

    /* FIN is always set and the opcode is always written: this function emits
     * neither continuation nor non-final frames */
    dst[0] = (char) (128 | (compressed ? SND_COMPRESSED : 0));
    dst[0] |= (char) opCode;

    char mask[4] = {0, 0, 0, 0};
    if (!isServer) {
        dst[1] |= 0x80;
        /* NOTE(review): rand() is not a strong entropy source and may yield
         * fewer than 32 random bits - confirm this is acceptable for masking keys */
        uint32_t random = rand();
        memcpy(mask, &random, 4);
        memcpy(dst + headerLength, mask, 4);
        headerLength += 4;
    }

    /* It is invalid to pass nullptr to memcpy, even though length is 0 */
    if (length) {
        memcpy(dst + headerLength, src, length);
    }

    if (!isServer) {
        /* Byte-precise masking: the 4-byte-at-a-time helpers of the parser would
         * write up to 3 bytes outside of the given buffer. The index is a
         * size_t so that it cannot overflow on very large payloads. */
        char *payload = dst + headerLength;
        for (size_t i = 0; i < length; i++) {
            payload[i] ^= mask[i & 3];
        }
    }
    return headerLength + length;
}

}

// essentially this is only a parser
/* Incremental WebSocket frame parser. All state lives in the WebSocketState
 * passed to consume(); events are reported through static functions on Impl:
 *   setCompressed(wState, user)              - called for frames with RSV1 set; a false result closes
 *   refusePayloadLength(len, wState, user)   - a true result closes
 *   handleFragment(data, len, remaining, opCode, fin, wState, user) - a true result stops parsing
 *   forceClose(wState, user)                 - called on any protocol violation
 */
template <bool isServer, typename Impl>
struct WIN32_EXPORT WebSocketProtocol {
public:
    static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
    static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
    static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;

protected:
    /* Accessors for the first two header bytes of a frame */
    static inline bool isFin(char *frame) {return *((unsigned char *) frame) & 128;}
    static inline unsigned char getOpCode(char *frame) {return *((unsigned char *) frame) & 15;}
    static inline unsigned char payloadLength(char *frame) {return ((unsigned char *) frame)[1] & 127;}
    static inline bool rsv23(char *frame) {return *((unsigned char *) frame) & 48;}
    static inline bool rsv1(char *frame) {return *((unsigned char *) frame) & 64;}

    /* XORs src with mask into dst in (length >> 2) + 1 groups of four bytes.
     * This always processes 1 to 4 bytes more than length, so both buffers
     * need that much slack after the payload. */
    static inline void unmaskImprecise(char *dst, char *src, char *mask, unsigned int length) {
        for (unsigned int n = (length >> 2) + 1; n; n--) {
            *(dst++) = *(src++) ^ mask[0];
            *(dst++) = *(src++) ^ mask[1];
            *(dst++) = *(src++) ^ mask[2];
            *(dst++) = *(src++) ^ mask[3];
        }
    }

    /* Same as unmaskImprecise, but copies the key first so that dst may overlap maskPtr. */
    static inline void unmaskImpreciseCopyMask(char *dst, char *src, char *maskPtr, unsigned int length) {
        char mask[4] = {maskPtr[0], maskPtr[1], maskPtr[2], maskPtr[3]};
        unmaskImprecise(dst, src, mask, length);
    }

    /* Rotates the key in place so that original byte i ends up at index (i + offset) % 4. */
    static inline void rotateMask(unsigned int offset, char *mask) {
        char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
        mask[(0 + offset) % 4] = originalMask[0];
        mask[(1 + offset) % 4] = originalMask[1];
        mask[(2 + offset) % 4] = originalMask[2];
        mask[(3 + offset) % 4] = originalMask[3];
    }

    /* XORs [data, stop) with mask, four bytes per step. When the range is not
     * a multiple of four the last step writes up to 3 bytes past stop. */
    static inline void unmaskInplace(char *data, char *stop, char *mask) {
        while (data < stop) {
            *(data++) ^= mask[0];
            *(data++) ^= mask[1];
            *(data++) ^= mask[2];
            *(data++) ^= mask[3];
        }
    }

    /* Handles one frame whose header (MESSAGE_HEADER bytes, already known to be
     * available: consume() guarantees length >= MESSAGE_HEADER) starts at src.
     * Returns true when parsing must stop (connection closed, handler asked to
     * stop, or the payload is incomplete) and false when src/length have been
     * advanced past a complete frame and the caller may parse the next one. */
    template <unsigned int MESSAGE_HEADER, typename T>
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
        if (getOpCode(src)) {
            /* NOTE(review): only opcode 1 is rejected while a fragmented message
             * is open; a BINARY frame in that position is pushed on the stack
             * instead - confirm this is intended */
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
                Impl::forceClose(wState, user);
                return true;
            }
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
        } else if (wState->state.opStack == -1) {
            /* Continuation frame with nothing to continue */
            Impl::forceClose(wState, user);
            return true;
        }
        wState->state.lastFin = isFin(src);

        if (Impl::refusePayloadLength(payLength, wState, user)) {
            Impl::forceClose(wState, user);
            return true;
        }

        /* Compared against the bytes left after the header so that a huge
         * 64-bit payLength cannot wrap around and pass as "fully buffered" */
        if (static_cast<uint64_t>(payLength) <= static_cast<uint64_t>(length - MESSAGE_HEADER)) {
            if (isServer) {
                /* Unmask over the key itself; the first two header bytes stay intact for isFin */
                unmaskImpreciseCopyMask(src + MESSAGE_HEADER - 4, src + MESSAGE_HEADER, src + MESSAGE_HEADER - 4, (unsigned int) payLength);
                if (Impl::handleFragment(src + MESSAGE_HEADER - 4, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
                    return true;
                }
            } else {
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
                    return true;
                }
            }

            if (isFin(src)) {
                wState->state.opStack--;
            }

            src += payLength + MESSAGE_HEADER;
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
            wState->state.spillLength = 0;
            return false;
        } else {
            /* Only part of the payload has arrived: deliver what is there and
             * remember how much is still missing */
            wState->state.spillLength = 0;
            wState->state.wantsHead = false;
            /* NOTE(review): truncated to 32 bits - presumably Impl::refusePayloadLength
             * rejects payloads that do not fit; confirm */
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
            bool fin = isFin(src);
            if (isServer) {
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
                unmaskImprecise(src, src + MESSAGE_HEADER, wState->mask, length - MESSAGE_HEADER);
                /* Align the saved key with the first byte of the next chunk */
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
            } else {
                src += MESSAGE_HEADER;
            }
            Impl::handleFragment(src, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
            return true;
        }
    }

    /* Handles payload bytes of a frame whose header was parsed by an earlier
     * call. Returns true when the frame is complete and header parsing should
     * resume at the advanced src/length, false otherwise. */
    static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
        if (wState->remainingBytes <= length) {
            if (isServer) {
                /* Whole groups of four in bulk, then the 0..3 tail bytes precisely.
                 * size_t keeps the offset arithmetic from overflowing an int. */
                size_t groups = wState->remainingBytes >> 2;
                unmaskInplace(src, src + groups * 4, wState->mask);
                for (unsigned int i = 0, tail = wState->remainingBytes % 4; i < tail; i++) {
                    src[groups * 4 + i] ^= wState->mask[i];
                }
            }

            if (Impl::handleFragment(src, wState->remainingBytes, 0, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
                return false;
            }

            if (wState->state.lastFin) {
                wState->state.opStack--;
            }

            src += wState->remainingBytes;
            length -= wState->remainingBytes;
            wState->state.wantsHead = true;
            return true;
        } else {
            if (isServer) {
                /* Imprecise: touches up to 4 bytes past length */
                unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
            }

            wState->remainingBytes -= length;
            if (Impl::handleFragment(src, length, wState->remainingBytes, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
                return false;
            }

            if (isServer && length % 4) {
                rotateMask(4 - (length % 4), wState->mask);
            }
            return false;
        }
    }

public:
    WebSocketProtocol() {

    }

    /* Feeds length received bytes at src to the parser.
     *
     * The buffer is modified in place. Bytes of an incomplete header saved by
     * the previous call are copied in front of src, and unmasking may touch
     * bytes after the data, so the caller must provide CONSUME_PRE_PADDING
     * writable bytes before src and CONSUME_POST_PADDING after src + length. */
    static inline void consume(char *src, unsigned int length, WebSocketState<isServer> *wState, void *user) {
        if (wState->state.spillLength) {
            src -= wState->state.spillLength;
            length += wState->state.spillLength;
            memcpy(src, wState->state.spill, wState->state.spillLength);
        }
        if (wState->state.wantsHead) {
            parseNext:
            while (length >= SHORT_MESSAGE_HEADER) {

                // invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame
                if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) || (getOpCode(src) > 2 && getOpCode(src) < 8) ||
                        getOpCode(src) > 10 || (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
                    Impl::forceClose(wState, user);
                    return;
                }

                if (payloadLength(src) < 126) {
                    if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(payloadLength(src), src, length, wState, user)) {
                        return;
                    }
                } else if (payloadLength(src) == 126) {
                    if (length < MEDIUM_MESSAGE_HEADER) {
                        break;
                    } else if (consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(protocol::cond_byte_swap<uint16_t>(protocol::bit_cast<uint16_t>(src + 2)), src, length, wState, user)) {
                        return;
                    }
                } else if (length < LONG_MESSAGE_HEADER) {
                    break;
                } else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(protocol::cond_byte_swap<uint64_t>(protocol::bit_cast<uint64_t>(src + 2)), src, length, wState, user)) {
                    return;
                }
            }
            /* Fewer bytes than a full header are left: keep them for the next call */
            if (length) {
                memcpy(wState->state.spill, src, length);
                wState->state.spillLength = length & 0xf;
            }
        } else if (consumeContinuation(src, length, wState, user)) {
            goto parseNext;
        }
    }

    static const int CONSUME_POST_PADDING = 4;
    static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
};

}

#endif // UWS_WEBSOCKETPROTOCOL_H