Merge commit 'c992cb4e42cc223f67ede0e48d7ff3f4947af0c6' as 'test/compatibility/C/uWebSockets'

This commit is contained in:
Benjamin Sergeant
2020-01-04 15:41:03 -08:00
68 changed files with 9564 additions and 0 deletions

View File

@ -0,0 +1,348 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_APP_H
#define UWS_APP_H
/* An app is a convenience wrapper of some of the most used fuctionalities and allows a
* builder-pattern kind of init. Apps operate on the implicit thread local Loop */
#include "HttpContext.h"
#include "HttpResponse.h"
#include "WebSocketContext.h"
#include "WebSocket.h"
#include "WebSocketExtensions.h"
#include "WebSocketHandshake.h"
namespace uWS {
/* Compress options (really more like PerMessageDeflateOptions) */
enum CompressOptions {
/* Compression disabled */
DISABLED = 0,
/* We compress using a shared non-sliding window. No added memory usage, worse compression. */
SHARED_COMPRESSOR = 1,
/* We compress using a dedicated sliding window. Major memory usage added, better compression of similarly repeated messages. */
DEDICATED_COMPRESSOR = 2
};
template <bool SSL>
struct TemplatedApp {
private:
/* The app always owns at least one http context, but creates websocket contexts on demand */
HttpContext<SSL> *httpContext;
std::vector<WebSocketContext<SSL, true> *> webSocketContexts;
public:
/* Attaches a "filter" function to track socket connections/disconnections */
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
httpContext->filter(std::move(filterHandler));
}
/* Publishes a message to all websocket contexts */
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress = false) {
for (auto *webSocketContext : webSocketContexts) {
webSocketContext->getExt()->publish(topic, message, opCode, compress);
}
}
~TemplatedApp() {
/* Let's just put everything here */
if (httpContext) {
httpContext->free();
for (auto *webSocketContext : webSocketContexts) {
webSocketContext->free();
}
}
}
/* Disallow copying, only move */
TemplatedApp(const TemplatedApp &other) = delete;
TemplatedApp(TemplatedApp &&other) {
/* Move HttpContext */
httpContext = other.httpContext;
other.httpContext = nullptr;
/* Move webSocketContexts */
webSocketContexts = std::move(other.webSocketContexts);
}
TemplatedApp(us_socket_context_options_t options = {}) {
httpContext = uWS::HttpContext<SSL>::create(uWS::Loop::get(), options);
}
bool constructorFailed() {
return !httpContext;
}
struct WebSocketBehavior {
CompressOptions compression = DISABLED;
int maxPayloadLength = 16 * 1024;
int idleTimeout = 120;
int maxBackpressure = 1 * 1024 * 1204;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, HttpRequest *)> open = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> message = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> drain = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> ping = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> pong = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, int, std::string_view)> close = nullptr;
};
template <typename UserData>
TemplatedApp &&ws(std::string pattern, WebSocketBehavior &&behavior) {
/* Don't compile if alignment rules cannot be satisfied */
static_assert(alignof(UserData) <= LIBUS_EXT_ALIGNMENT,
"µWebSockets cannot satisfy UserData alignment requirements. You need to recompile µSockets with LIBUS_EXT_ALIGNMENT adjusted accordingly.");
/* Every route has its own websocket context with its own behavior and user data type */
auto *webSocketContext = WebSocketContext<SSL, true>::create(Loop::get(), (us_socket_context_t *) httpContext);
/* We need to clear this later on */
webSocketContexts.push_back(webSocketContext);
/* Quick fix to disable any compression if set */
#ifdef UWS_NO_ZLIB
behavior.compression = uWS::DISABLED;
#endif
/* If we are the first one to use compression, initialize it */
if (behavior.compression) {
LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, webSocketContext->getSocketContext()));
/* Initialize loop's deflate inflate streams */
if (!loopData->zlibContext) {
loopData->zlibContext = new ZlibContext;
loopData->inflationStream = new InflationStream;
loopData->deflationStream = new DeflationStream;
}
}
/* Copy all handlers */
webSocketContext->getExt()->messageHandler = std::move(behavior.message);
webSocketContext->getExt()->drainHandler = std::move(behavior.drain);
webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket<SSL, true> *ws, int code, std::string_view message) mutable {
closeHandler(ws, code, message);
/* Destruct user data after returning from close handler */
((UserData *) ws->getUserData())->~UserData();
});
/* Copy settings */
webSocketContext->getExt()->maxPayloadLength = behavior.maxPayloadLength;
webSocketContext->getExt()->idleTimeout = behavior.idleTimeout;
webSocketContext->getExt()->maxBackpressure = behavior.maxBackpressure;
httpContext->onHttp("get", pattern, [webSocketContext, httpContext = this->httpContext, behavior = std::move(behavior)](auto *res, auto *req) mutable {
/* If we have this header set, it's a websocket */
std::string_view secWebSocketKey = req->getHeader("sec-websocket-key");
if (secWebSocketKey.length() == 24) {
/* Note: OpenSSL can be used here to speed this up somewhat */
char secWebSocketAccept[29] = {};
WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept);
res->writeStatus("101 Switching Protocols")
->writeHeader("Upgrade", "websocket")
->writeHeader("Connection", "Upgrade")
->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept);
/* Select first subprotocol if present */
std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol");
if (secWebSocketProtocol.length()) {
res->writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(',')));
}
/* Negotiate compression */
bool perMessageDeflate = false;
bool slidingDeflateWindow = false;
if (behavior.compression != DISABLED) {
std::string_view extensions = req->getHeader("sec-websocket-extensions");
if (extensions.length()) {
/* We never support client context takeover (the client cannot compress with a sliding window). */
int wantedOptions = PERMESSAGE_DEFLATE | CLIENT_NO_CONTEXT_TAKEOVER;
/* Shared compressor is the default */
if (behavior.compression == SHARED_COMPRESSOR) {
/* Disable per-socket compressor */
wantedOptions |= SERVER_NO_CONTEXT_TAKEOVER;
}
/* isServer = true */
ExtensionsNegotiator<true> extensionsNegotiator(wantedOptions);
extensionsNegotiator.readOffer(extensions);
/* Todo: remove these mid string copies */
std::string offer = extensionsNegotiator.generateOffer();
if (offer.length()) {
res->writeHeader("Sec-WebSocket-Extensions", offer);
}
/* Did we negotiate permessage-deflate? */
if (extensionsNegotiator.getNegotiatedOptions() & PERMESSAGE_DEFLATE) {
perMessageDeflate = true;
}
/* Is the server allowed to compress with a sliding window? */
if (!(extensionsNegotiator.getNegotiatedOptions() & SERVER_NO_CONTEXT_TAKEOVER)) {
slidingDeflateWindow = true;
}
}
}
/* This will add our mark */
res->upgrade();
/* Move any backpressure */
std::string backpressure(std::move(((AsyncSocketData<SSL> *) res->getHttpResponseData())->buffer));
/* Keep any fallback buffer alive until we returned from open event, keeping req valid */
std::string fallback(std::move(res->getHttpResponseData()->salvageFallbackBuffer()));
/* Destroy HttpResponseData */
res->getHttpResponseData()->~HttpResponseData();
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) us_socket_context_adopt_socket(SSL,
(us_socket_context_t *) webSocketContext, (us_socket_t *) res, sizeof(WebSocketData) + sizeof(UserData));
/* Update corked socket in case we got a new one (assuming we always are corked in handlers). */
webSocket->AsyncSocket<SSL>::cork();
/* Initialize websocket with any moved backpressure intact */
httpContext->upgradeToWebSocket(
webSocket->init(perMessageDeflate, slidingDeflateWindow, std::move(backpressure))
);
/* Emit open event and start the timeout */
if (behavior.open) {
us_socket_timeout(SSL, (us_socket_t *) webSocket, behavior.idleTimeout);
/* Default construct the UserData right before calling open handler */
new (webSocket->getUserData()) UserData;
behavior.open(webSocket, req);
}
/* We are going to get uncorked by the Http get return */
/* We do not need to check for any close or shutdown here as we immediately return from get handler */
} else {
/* Tell the router that we did not handle this request */
req->setYield(true);
}
}, true);
return std::move(*this);
}
TemplatedApp &&get(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("get", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&post(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("post", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&options(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("options", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&del(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("delete", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&patch(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("patch", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&put(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("put", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&head(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("head", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&connect(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("connect", pattern, std::move(handler));
return std::move(*this);
}
TemplatedApp &&trace(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("trace", pattern, std::move(handler));
return std::move(*this);
}
/* This one catches any method */
TemplatedApp &&any(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("*", pattern, std::move(handler));
return std::move(*this);
}
/* Host, port, callback */
TemplatedApp &&listen(std::string host, int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
if (!host.length()) {
return listen(port, std::move(handler));
}
handler(httpContext->listen(host.c_str(), port, 0));
return std::move(*this);
}
/* Host, port, options, callback */
TemplatedApp &&listen(std::string host, int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
if (!host.length()) {
return listen(port, options, std::move(handler));
}
handler(httpContext->listen(host.c_str(), port, options));
return std::move(*this);
}
/* Port, callback */
TemplatedApp &&listen(int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
handler(httpContext->listen(nullptr, port, 0));
return std::move(*this);
}
/* Port, options, callback */
TemplatedApp &&listen(int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
handler(httpContext->listen(nullptr, port, options));
return std::move(*this);
}
TemplatedApp &&run() {
uWS::run();
return std::move(*this);
}
};
typedef TemplatedApp<false> App;
typedef TemplatedApp<true> SSLApp;
}
#endif // UWS_APP_H

View File

@ -0,0 +1,228 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_ASYNCSOCKET_H
#define UWS_ASYNCSOCKET_H
/* This class implements async socket memory management strategies */
#include "LoopData.h"
#include "AsyncSocketData.h"
namespace uWS {
template <bool, bool> struct WebSocketContext;
template <bool SSL>
struct AsyncSocket {
template <bool> friend struct HttpContext;
template <bool, bool> friend struct WebSocketContext;
template <bool> friend struct WebSocketContextData;
friend struct TopicTree;
protected:
/* Get loop data for socket */
LoopData *getLoopData() {
return (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) this)));
}
/* Get socket extension */
AsyncSocketData<SSL> *getAsyncSocketData() {
return (AsyncSocketData<SSL> *) us_socket_ext(SSL, (us_socket_t *) this);
}
/* Socket timeout */
void timeout(unsigned int seconds) {
us_socket_timeout(SSL, (us_socket_t *) this, seconds);
}
/* Shutdown socket without any automatic drainage */
void shutdown() {
us_socket_shutdown(SSL, (us_socket_t *) this);
}
/* Immediately close socket */
us_socket_t *close() {
return us_socket_close(SSL, (us_socket_t *) this);
}
/* Cork this socket. Only one socket may ever be corked per-loop at any given time */
void cork() {
/* What if another socket is corked? */
getLoopData()->corkedSocket = this;
}
/* Returns wheter we are corked or not */
bool isCorked() {
return getLoopData()->corkedSocket == this;
}
/* Returns whether we could cork (it is free) */
bool canCork() {
return getLoopData()->corkedSocket == nullptr;
}
/* Returns a suitable buffer for temporary assemblation of send data */
std::pair<char *, bool> getSendBuffer(size_t size) {
/* If we are corked and we have room, return the cork buffer itself */
LoopData *loopData = getLoopData();
if (loopData->corkedSocket == this && loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE) {
char *sendBuffer = loopData->corkBuffer + loopData->corkOffset;
loopData->corkOffset += (int) size;
return {sendBuffer, false};
} else {
/* Slow path for now, we want to always be corked if possible */
return {(char *) malloc(size), true};
}
}
/* Returns the user space backpressure. */
int getBufferedAmount() {
return (int) getAsyncSocketData()->buffer.size();
}
/* Returns the remote IP address or empty string on failure */
std::string_view getRemoteAddress() {
static thread_local char buf[16];
int ipLength = 16;
us_socket_remote_address(SSL, (us_socket_t *) this, buf, &ipLength);
return std::string_view(buf, ipLength);
}
/* Write in three levels of prioritization: cork-buffer, syscall, socket-buffer. Always drain if possible.
* Returns pair of bytes written (anywhere) and wheter or not this call resulted in the polling for
* writable (or we are in a state that implies polling for writable). */
std::pair<int, bool> write(const char *src, int length, bool optionally = false, int nextLength = 0) {
/* Fake success if closed, simple fix to allow uncork of closed socket to succeed */
if (us_socket_is_closed(SSL, (us_socket_t *) this)) {
return {length, false};
}
LoopData *loopData = getLoopData();
AsyncSocketData<SSL> *asyncSocketData = getAsyncSocketData();
/* We are limited if we have a per-socket buffer */
if (asyncSocketData->buffer.length()) {
/* Write off as much as we can */
int written = us_socket_write(SSL, (us_socket_t *) this, asyncSocketData->buffer.data(), (int) asyncSocketData->buffer.length(), /*nextLength != 0 | */length);
/* On failure return, otherwise continue down the function */
if ((unsigned int) written < asyncSocketData->buffer.length()) {
/* Update buffering (todo: we can do better here if we keep track of what happens to this guy later on) */
asyncSocketData->buffer = asyncSocketData->buffer.substr(written);
if (optionally) {
/* Thankfully we can exit early here */
return {0, true};
} else {
/* This path is horrible and points towards erroneous usage */
asyncSocketData->buffer.append(src, length);
return {length, true};
}
}
/* At this point we simply have no buffer and can continue as normal */
asyncSocketData->buffer.clear();
}
if (length) {
if (loopData->corkedSocket == this) {
/* We are corked */
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= length) {
/* If the entire chunk fits in cork buffer */
memcpy(loopData->corkBuffer + loopData->corkOffset, src, length);
loopData->corkOffset += length;
/* Fall through to default return */
} else {
/* Strategy differences between SSL and non-SSL regarding syscall minimizing */
if constexpr (SSL) {
/* Cork up as much as we can */
int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped);
loopData->corkOffset = LoopData::CORK_BUFFER_SIZE;
auto [written, failed] = uncork(src + stripped, length - stripped, optionally);
return {written + stripped, failed};
}
/* For non-SSL we take the penalty of two syscalls */
return uncork(src, length, optionally);
}
} else {
/* We are not corked */
int written = us_socket_write(SSL, (us_socket_t *) this, src, length, nextLength != 0);
/* Did we fail? */
if (written < length) {
/* If the write was optional then just bail out */
if (optionally) {
return {written, true};
}
/* Fall back to worst possible case (should be very rare for HTTP) */
/* At least we can reserve room for next chunk if we know it up front */
if (nextLength) {
asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + length - written + nextLength);
}
/* Buffer this chunk */
asyncSocketData->buffer.append(src + written, length - written);
/* Return the failure */
return {length, true};
}
/* Fall through to default return */
}
}
/* Default fall through return */
return {length, false};
}
/* Uncork this socket and flush or buffer any corked and/or passed data. It is essential to remember doing this. */
/* It does NOT count bytes written from cork buffer (they are already accounted for in the write call responsible for its corking)! */
std::pair<int, bool> uncork(const char *src = nullptr, int length = 0, bool optionally = false) {
LoopData *loopData = getLoopData();
if (loopData->corkedSocket == this) {
loopData->corkedSocket = nullptr;
if (loopData->corkOffset) {
/* Corked data is already accounted for via its write call */
auto [written, failed] = write(loopData->corkBuffer, loopData->corkOffset, false, length);
loopData->corkOffset = 0;
if (failed) {
/* We do not need to care for buffering here, write does that */
return {0, true};
}
}
/* We should only return with new writes, not things written to cork already */
return write(src, length, optionally, 0);
} else {
/* We are not even corked! */
return {0, false};
}
}
};
}
#endif // UWS_ASYNCSOCKET_H

View File

@ -0,0 +1,43 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_ASYNCSOCKETDATA_H
#define UWS_ASYNCSOCKETDATA_H
#include <string>
namespace uWS {
/* Depending on how we want AsyncSocket to function, this will need to change */
template <bool SSL>
struct AsyncSocketData {
/* This will do for now */
std::string buffer;
/* Allow move constructing us */
AsyncSocketData(std::string &&backpressure) : buffer(std::move(backpressure)) {
}
/* Or emppty */
AsyncSocketData() = default;
};
}
#endif // UWS_ASYNCSOCKETDATA_H

View File

@ -0,0 +1,384 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPCONTEXT_H
#define UWS_HTTPCONTEXT_H
/* This class defines the main behavior of HTTP and emits various events */
#include "Loop.h"
#include "HttpContextData.h"
#include "HttpResponseData.h"
#include "AsyncSocket.h"
#include <string_view>
#include <iostream>
#include "f2/function2.hpp"
namespace uWS {
template<bool> struct HttpResponse;
template <bool SSL>
struct HttpContext {
template<bool> friend struct TemplatedApp;
private:
HttpContext() = delete;
/* Maximum delay allowed until an HTTP connection is terminated due to outstanding request or rejected data (slow loris protection) */
static const int HTTP_IDLE_TIMEOUT_S = 10;
us_socket_context_t *getSocketContext() {
return (us_socket_context_t *) this;
}
static us_socket_context_t *getSocketContext(us_socket_t *s) {
return (us_socket_context_t *) us_socket_context(SSL, s);
}
HttpContextData<SSL> *getSocketContextData() {
return (HttpContextData<SSL> *) us_socket_context_ext(SSL, getSocketContext());
}
static HttpContextData<SSL> *getSocketContextDataS(us_socket_t *s) {
return (HttpContextData<SSL> *) us_socket_context_ext(SSL, getSocketContext(s));
}
/* Init the HttpContext by registering libusockets event handlers */
HttpContext<SSL> *init() {
/* Handle socket connections */
us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int is_client, char *ip, int ip_length) {
/* Any connected socket should timeout until it has a request */
us_socket_timeout(SSL, s, HTTP_IDLE_TIMEOUT_S);
/* Init socket ext */
new (us_socket_ext(SSL, s)) HttpResponseData<SSL>;
/* Call filter */
HttpContextData<SSL> *httpContextData = getSocketContextDataS(s);
for (auto &f : httpContextData->filterHandlers) {
f((HttpResponse<SSL> *) s, 1);
}
return s;
});
/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s) {
/* Get socket ext */
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, s);
/* Call filter */
HttpContextData<SSL> *httpContextData = getSocketContextDataS(s);
for (auto &f : httpContextData->filterHandlers) {
f((HttpResponse<SSL> *) s, -1);
}
/* Signal broken HTTP request only if we have a pending request */
if (httpResponseData->onAborted) {
httpResponseData->onAborted();
}
/* Destruct socket ext */
httpResponseData->~HttpResponseData<SSL>();
return s;
});
/* Handle HTTP data streams */
us_socket_context_on_data(SSL, getSocketContext(), [](us_socket_t *s, char *data, int length) {
// total overhead is about 210k down to 180k
// ~210k req/sec is the original perf with write in data
// ~200k req/sec is with cork and formatting
// ~190k req/sec is with http parsing
// ~180k - 190k req/sec is with varying routing
HttpContextData<SSL> *httpContextData = getSocketContextDataS(s);
/* Do not accept any data while in shutdown state */
if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
return s;
}
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, s);
/* Cork this socket */
((AsyncSocket<SSL> *) s)->cork();
// clients need to know the cursor after http parse, not servers!
// how far did we read then? we need to know to continue with websocket parsing data? or?
/* The return value is entirely up to us to interpret. The HttpParser only care for whether the returned value is DIFFERENT or not from passed user */
void *returnedSocket = httpResponseData->consumePostPadded(data, length, s, [httpContextData](void *s, uWS::HttpRequest *httpRequest) -> void * {
/* For every request we reset the timeout and hang until user makes action */
/* Warning: if we are in shutdown state, resetting the timer is a security issue! */
us_socket_timeout(SSL, (us_socket_t *) s, 0);
/* Reset httpResponse */
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, (us_socket_t *) s);
httpResponseData->offset = 0;
/* Are we not ready for another request yet? Terminate the connection. */
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) {
us_socket_close(SSL, (us_socket_t *) s);
return nullptr;
}
/* Mark pending request and emit it */
httpResponseData->state = HttpResponseData<SSL>::HTTP_RESPONSE_PENDING;
/* Route the method and URL */
httpContextData->router.getUserData() = {(HttpResponse<SSL> *) s, httpRequest};
if (!httpContextData->router.route(httpRequest->getMethod(), httpRequest->getUrl())) {
/* We have to force close this socket as we have no handler for it */
us_socket_close(SSL, (us_socket_t *) s);
return nullptr;
}
/* First of all we need to check if this socket was deleted due to upgrade */
if (httpContextData->upgradedWebSocket) {
/* We differ between closed and upgraded below */
return nullptr;
}
/* Was the socket closed? */
if (us_socket_is_closed(SSL, (struct us_socket_t *) s)) {
return nullptr;
}
/* We absolutely have to terminate parsing if shutdown */
if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
return nullptr;
}
/* Returning from a request handler without responding or attaching an onAborted handler is ill-use */
if (!((HttpResponse<SSL> *) s)->hasResponded() && !httpResponseData->onAborted) {
/* Throw exception here? */
std::cerr << "Error: Returning from a request handler without responding or attaching an abort handler is forbidden!" << std::endl;
std::terminate();
}
/* If we have not responded and we have a data handler, we need to timeout to enfore client sending the data */
if (!((HttpResponse<SSL> *) s)->hasResponded() && httpResponseData->inStream) {
us_socket_timeout(SSL, (us_socket_t *) s, HTTP_IDLE_TIMEOUT_S);
}
/* Continue parsing */
return s;
}, [httpResponseData](void *user, std::string_view data, bool fin) -> void * {
/* We always get an empty chunk even if there is no data */
if (httpResponseData->inStream) {
/* Todo: can this handle timeout for non-post as well? */
if (fin) {
/* If we just got the last chunk (or empty chunk), disable timeout */
us_socket_timeout(SSL, (struct us_socket_t *) user, 0);
} else {
/* We still have some more data coming in later, so reset timeout */
us_socket_timeout(SSL, (struct us_socket_t *) user, HTTP_IDLE_TIMEOUT_S);
}
/* We might respond in the handler, so do not change timeout after this */
httpResponseData->inStream(data, fin);
/* Was the socket closed? */
if (us_socket_is_closed(SSL, (struct us_socket_t *) user)) {
return nullptr;
}
/* We absolutely have to terminate parsing if shutdown */
if (us_socket_is_shut_down(SSL, (us_socket_t *) user)) {
return nullptr;
}
/* If we were given the last data chunk, reset data handler to ensure following
* requests on the same socket won't trigger any previously registered behavior */
if (fin) {
httpResponseData->inStream = nullptr;
}
}
return user;
}, [](void *user) {
/* Close any socket on HTTP errors */
us_socket_close(SSL, (us_socket_t *) user);
return nullptr;
});
/* We need to uncork in all cases, except for nullptr (closed socket, or upgraded socket) */
if (returnedSocket != nullptr) {
/* Timeout on uncork failure */
auto [written, failed] = ((AsyncSocket<SSL> *) returnedSocket)->uncork();
if (failed) {
/* All Http sockets timeout by this, and this behavior match the one in HttpResponse::cork */
/* Warning: both HTTP_IDLE_TIMEOUT_S and HTTP_TIMEOUT_S are 10 seconds and both are used the same */
((AsyncSocket<SSL> *) s)->timeout(HTTP_IDLE_TIMEOUT_S);
}
return (us_socket_t *) returnedSocket;
}
/* If we upgraded, check here (differ between nullptr close and nullptr upgrade) */
if (httpContextData->upgradedWebSocket) {
/* This path is only for upgraded websockets */
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) httpContextData->upgradedWebSocket;
/* Uncork here as well (note: what if we failed to uncork and we then pub/sub before we even upgraded?) */
/*auto [written, failed] = */asyncSocket->uncork();
/* Reset upgradedWebSocket before we return */
httpContextData->upgradedWebSocket = nullptr;
/* Return the new upgraded websocket */
return (us_socket_t *) asyncSocket;
}
/* It is okay to uncork a closed socket and we need to */
((AsyncSocket<SSL> *) s)->uncork();
/* We cannot return nullptr to the underlying stack in any case */
return s;
});
/* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */
us_socket_context_on_writable(SSL, getSocketContext(), [](us_socket_t *s) {
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) asyncSocket->getAsyncSocketData();
/* Ask the developer to write data and return success (true) or failure (false), OR skip sending anything and return success (true). */
if (httpResponseData->onWritable) {
/* We are now writable, so hang timeout again, the user does not have to do anything so we should hang until end or tryEnd rearms timeout */
us_socket_timeout(SSL, s, 0);
/* We expect the developer to return whether or not write was successful (true).
* If write was never called, the developer should still return true so that we may drain. */
bool success = httpResponseData->onWritable(httpResponseData->offset);
/* The developer indicated that their onWritable failed. */
if (!success) {
/* Skip testing if we can drain anything since that might perform an extra syscall */
return s;
}
/* We don't want to fall through since we don't want to mess with timeout.
* It makes little sense to drain any backpressure when the user has registered onWritable. */
return s;
}
/* Drain any socket buffer, this might empty our backpressure and thus finish the request */
/*auto [written, failed] = */asyncSocket->write(nullptr, 0, true, 0);
/* Expect another writable event, or another request within the timeout */
asyncSocket->timeout(HTTP_IDLE_TIMEOUT_S);
return s;
});
/* Handle FIN, HTTP does not support half-closed sockets, so simply close */
us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) {
/* We do not care for half closed sockets */
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
return asyncSocket->close();
});
/* Handle socket timeouts, simply close them so to not confuse client with FIN */
us_socket_context_on_timeout(SSL, getSocketContext(), [](us_socket_t *s) {
/* Force close rather than gracefully shutdown and risk confusing the client with a complete download */
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
return asyncSocket->close();
});
return this;
}
/* Used by App in its WebSocket handler */
void upgradeToWebSocket(void *newSocket) {
HttpContextData<SSL> *httpContextData = getSocketContextData();
httpContextData->upgradedWebSocket = newSocket;
}
public:
/* Construct a new HttpContext using specified loop */
static HttpContext *create(Loop *loop, us_socket_context_options_t options = {}) {
HttpContext *httpContext;
httpContext = (HttpContext *) us_create_socket_context(SSL, (us_loop_t *) loop, sizeof(HttpContextData<SSL>), options);
if (!httpContext) {
return nullptr;
}
/* Init socket context data */
new ((HttpContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) httpContext)) HttpContextData<SSL>();
return httpContext->init();
}
/* Destruct the HttpContext, it does not follow RAII */
void free() {
/* Destruct socket context data */
HttpContextData<SSL> *httpContextData = getSocketContextData();
httpContextData->~HttpContextData<SSL>();
/* Free the socket context in whole */
us_socket_context_free(SSL, getSocketContext());
}
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
getSocketContextData()->filterHandlers.emplace_back(std::move(filterHandler));
}
/* Register an HTTP route handler acording to URL pattern */
void onHttp(std::string method, std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler, bool upgrade = false) {
HttpContextData<SSL> *httpContextData = getSocketContextData();
/* Todo: This is ugly, fix */
std::vector<std::string> methods;
if (method == "*") {
methods = httpContextData->router.methods;
} else {
methods = {method};
}
httpContextData->router.add(methods, pattern, [handler = std::move(handler)](auto *r) mutable {
auto user = r->getUserData();
user.httpRequest->setYield(false);
user.httpRequest->setParameters(r->getParameters());
handler(user.httpResponse, user.httpRequest);
/* If any handler yielded, the router will keep looking for a suitable handler. */
if (user.httpRequest->getYield()) {
return false;
}
return true;
}, method == "*" ? httpContextData->router.LOW_PRIORITY : (upgrade ? httpContextData->router.HIGH_PRIORITY : httpContextData->router.MEDIUM_PRIORITY));
}
/* Listen to port using this HttpContext */
us_listen_socket_t *listen(const char *host, int port, int options) {
return us_socket_context_listen(SSL, getSocketContext(), host, port, options, sizeof(HttpResponseData<SSL>));
}
};
}
#endif // UWS_HTTPCONTEXT_H

View File

@ -0,0 +1,48 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPCONTEXTDATA_H
#define UWS_HTTPCONTEXTDATA_H
#include "HttpRouter.h"
#include <vector>
#include "f2/function2.hpp"
namespace uWS {
template<bool> struct HttpResponse;
struct HttpRequest;
template <bool SSL>
struct alignas(16) HttpContextData {
template <bool> friend struct HttpContext;
template <bool> friend struct HttpResponse;
private:
std::vector<fu2::unique_function<void(HttpResponse<SSL> *, int)>> filterHandlers;
struct RouterData {
HttpResponse<SSL> *httpResponse;
HttpRequest *httpRequest;
};
HttpRouter<RouterData> router;
void *upgradedWebSocket = nullptr;
};
}
#endif // UWS_HTTPCONTEXTDATA_H

View File

@ -0,0 +1,334 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPPARSER_H
#define UWS_HTTPPARSER_H
// todo: HttpParser is in need of a few clean-ups and refactorings
/* The HTTP parser is an independent module subject to unit testing / fuzz testing */
#include <string>
#include <cstring>
#include <algorithm>
#include "f2/function2.hpp"
namespace uWS {
/* We require at least this much post padding */
static const int MINIMUM_HTTP_POST_PADDING = 32;
struct HttpRequest {
friend struct HttpParser;
private:
const static int MAX_HEADERS = 50;
struct Header {
std::string_view key, value;
} headers[MAX_HEADERS];
int querySeparator;
bool didYield;
std::pair<int, std::string_view *> currentParameters;
public:
bool getYield() {
return didYield;
}
/* Iteration over headers (key, value) */
struct HeaderIterator {
Header *ptr;
bool operator!=(const HeaderIterator &other) const {
/* Comparison with end is a special case */
if (ptr != other.ptr) {
return other.ptr || ptr->key.length();
}
return false;
}
HeaderIterator &operator++() {
ptr++;
return *this;
}
std::pair<std::string_view, std::string_view> operator*() const {
return {ptr->key, ptr->value};
}
};
HeaderIterator begin() {
return {headers + 1};
}
HeaderIterator end() {
return {nullptr};
}
/* If you do not want to handle this route */
void setYield(bool yield) {
didYield = yield;
}
std::string_view getHeader(std::string_view lowerCasedHeader) {
for (Header *h = headers; (++h)->key.length(); ) {
if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) {
return h->value;
}
}
return std::string_view(nullptr, 0);
}
std::string_view getUrl() {
return std::string_view(headers->value.data(), querySeparator);
}
std::string_view getMethod() {
return std::string_view(headers->key.data(), headers->key.length());
}
std::string_view getQuery() {
if (querySeparator < (int) headers->value.length()) {
/* Strip the initial ? */
return std::string_view(headers->value.data() + querySeparator + 1, headers->value.length() - querySeparator - 1);
} else {
return std::string_view(nullptr, 0);
}
}
void setParameters(std::pair<int, std::string_view *> parameters) {
currentParameters = parameters;
}
std::string_view getParameter(unsigned int index) {
if (currentParameters.first < (int) index) {
return {};
} else {
return currentParameters.second[index];
}
}
};
struct HttpParser {
private:
std::string fallback;
unsigned int remainingStreamingBytes = 0;
const size_t MAX_FALLBACK_SIZE = 1024 * 4;
static unsigned int toUnsignedInteger(std::string_view str) {
unsigned int unsignedIntegerValue = 0;
for (unsigned char c : str) {
unsignedIntegerValue = unsignedIntegerValue * 10 + (c - '0');
}
return unsignedIntegerValue;
}
static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers) {
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
for (unsigned int i = 0; i < HttpRequest::MAX_HEADERS; i++) {
for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32);
if (*postPaddedBuffer == '\r') {
if ((postPaddedBuffer != end) & (postPaddedBuffer[1] == '\n') & (i > 0)) {
headers->key = std::string_view(nullptr, 0);
return (unsigned int) ((postPaddedBuffer + 2) - start);
} else {
return 0;
}
} else {
headers->key = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey));
for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++);
preliminaryValue = postPaddedBuffer;
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer);
if (postPaddedBuffer && postPaddedBuffer[1] == '\n') {
headers->value = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue));
postPaddedBuffer += 2;
headers++;
} else {
return 0;
}
}
}
return 0;
}
// the only caller of getHeaders
template <int CONSUME_MINIMALLY>
std::pair<int, void *> fenceAndConsumePostPadded(char *data, int length, void *user, HttpRequest *req, fu2::unique_function<void *(void *, HttpRequest *)> &requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &dataHandler) {
int consumedTotal = 0;
data[length] = '\r';
for (int consumed; length && (consumed = getHeaders(data, data + length, req->headers)); ) {
data += consumed;
length -= consumed;
consumedTotal += consumed;
req->headers->value = std::string_view(req->headers->value.data(), std::max<int>(0, (int) req->headers->value.length() - 9));
/* Parse query */
const char *querySeparatorPtr = (const char *) memchr(req->headers->value.data(), '?', req->headers->value.length());
req->querySeparator = (int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data());
/* If returned socket is not what we put in we need
* to break here as we either have upgraded to
* WebSockets or otherwise closed the socket. */
void *returnedUser = requestHandler(user, req);
if (returnedUser != user) {
/* We are upgraded to WebSocket or otherwise broken */
return {consumedTotal, returnedUser};
}
// todo: do not check this for GET (get should not have a body)
// todo: also support reading chunked streams
std::string_view contentLengthString = req->getHeader("content-length");
if (contentLengthString.length()) {
remainingStreamingBytes = toUnsignedInteger(contentLengthString);
if (!CONSUME_MINIMALLY) {
unsigned int emittable = std::min<unsigned int>(remainingStreamingBytes, length);
dataHandler(user, std::string_view(data, emittable), emittable == remainingStreamingBytes);
remainingStreamingBytes -= emittable;
data += emittable;
length -= emittable;
consumedTotal += emittable;
}
} else {
/* Still emit an empty data chunk to signal no data */
dataHandler(user, {}, true);
}
if (CONSUME_MINIMALLY) {
break;
}
}
return {consumedTotal, user};
}
public:
/* We do this to prolong the validity of parsed headers by keeping only the fallback buffer alive */
std::string &&salvageFallbackBuffer() {
return std::move(fallback);
}
void *consumePostPadded(char *data, int length, void *user, fu2::unique_function<void *(void *, HttpRequest *)> &&requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &&dataHandler, fu2::unique_function<void *(void *)> &&errorHandler) {
HttpRequest req;
if (remainingStreamingBytes) {
// this is exactly the same as below!
// todo: refactor this
if (remainingStreamingBytes >= (unsigned int) length) {
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length);
remainingStreamingBytes -= length;
return returnedUser;
} else {
void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true);
data += remainingStreamingBytes;
length -= remainingStreamingBytes;
remainingStreamingBytes = 0;
if (returnedUser != user) {
return returnedUser;
}
}
} else if (fallback.length()) {
int had = (int) fallback.length();
int maxCopyDistance = (int) std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length);
/* We don't want fallback to be short string optimized, since we want to move it */
fallback.reserve(fallback.length() + maxCopyDistance + std::max<int>(MINIMUM_HTTP_POST_PADDING, sizeof(std::string)));
fallback.append(data, maxCopyDistance);
// break here on break
std::pair<int, void *> consumed = fenceAndConsumePostPadded<true>(fallback.data(), (int) fallback.length(), user, &req, requestHandler, dataHandler);
if (consumed.second != user) {
return consumed.second;
}
if (consumed.first) {
fallback.clear();
data += consumed.first - had;
length -= consumed.first - had;
if (remainingStreamingBytes) {
// this is exactly the same as above!
if (remainingStreamingBytes >= (unsigned int) length) {
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length);
remainingStreamingBytes -= length;
return returnedUser;
} else {
void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true);
data += remainingStreamingBytes;
length -= remainingStreamingBytes;
remainingStreamingBytes = 0;
if (returnedUser != user) {
return returnedUser;
}
}
}
} else {
if (fallback.length() == MAX_FALLBACK_SIZE) {
// note: you don't really need error handler, just return something strange!
// we could have it return a constant pointer to denote error!
return errorHandler(user);
}
return user;
}
}
std::pair<int, void *> consumed = fenceAndConsumePostPadded<false>(data, length, user, &req, requestHandler, dataHandler);
if (consumed.second != user) {
return consumed.second;
}
data += consumed.first;
length -= consumed.first;
if (length) {
if ((unsigned int) length < MAX_FALLBACK_SIZE) {
fallback.append(data, length);
} else {
return errorHandler(user);
}
}
// added for now
return user;
}
};
}
#endif // UWS_HTTPPARSER_H

View File

@ -0,0 +1,317 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPRESPONSE_H
#define UWS_HTTPRESPONSE_H
/* An HttpResponse is the channel on which you send back a response */
#include "AsyncSocket.h"
#include "HttpResponseData.h"
#include "HttpContextData.h"
#include "Utilities.h"
#include "f2/function2.hpp"
/* todo: tryWrite is missing currently, only send smaller segments with write */
namespace uWS {
/* Some pre-defined status constants to use with writeStatus */
static const char *HTTP_200_OK = "200 OK";
/* The general timeout for HTTP sockets */
static const int HTTP_TIMEOUT_S = 10;
template <bool SSL>
struct HttpResponse : public AsyncSocket<SSL> {
/* Solely used for getHttpResponseData() */
template <bool> friend struct TemplatedApp;
typedef AsyncSocket<SSL> Super;
private:
HttpResponseData<SSL> *getHttpResponseData() {
return (HttpResponseData<SSL> *) Super::getAsyncSocketData();
}
/* Write an unsigned 32-bit integer in hex */
void writeUnsignedHex(unsigned int value) {
char buf[10];
int length = utils::u32toaHex(value, buf);
/* For now we do this copy */
Super::write(buf, length);
}
/* Write an unsigned 32-bit integer */
void writeUnsigned(unsigned int value) {
char buf[10];
int length = utils::u32toa(value, buf);
/* For now we do this copy */
Super::write(buf, length);
}
/* When we are done with a response we mark it like so */
void markDone(HttpResponseData<SSL> *httpResponseData) {
httpResponseData->onAborted = nullptr;
/* Also remove onWritable so that we do not emit when draining behind the scenes. */
httpResponseData->onWritable = nullptr;
/* We are done with this request */
httpResponseData->state &= ~HttpResponseData<SSL>::HTTP_RESPONSE_PENDING;
}
/* Called only once per request */
void writeMark() {
writeHeader("uWebSockets", "v0.17");
}
/* Returns true on success, indicating that it might be feasible to write more data.
* Will start timeout if stream reaches totalSize or write failure. */
bool internalEnd(std::string_view data, int totalSize, bool optional, bool allowContentLength = true) {
/* Write status if not already done */
writeStatus(HTTP_200_OK);
/* If no total size given then assume this chunk is everything */
if (!totalSize) {
totalSize = (int) data.length();
}
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_WRITE_CALLED) {
/* We do not have tryWrite-like functionalities, so ignore optional in this path */
/* Do not allow sending 0 chunk here */
if (data.length()) {
Super::write("\r\n", 2);
writeUnsignedHex((unsigned int) data.length());
Super::write("\r\n", 2);
/* Ignoring optional for now */
Super::write(data.data(), (int) data.length());
}
/* Terminating 0 chunk */
Super::write("\r\n0\r\n\r\n", 7);
markDone(httpResponseData);
/* tryEnd can never fail when in chunked mode, since we do not have tryWrite (yet), only write */
Super::timeout(HTTP_TIMEOUT_S);
return true;
} else {
/* Write content-length on first call */
if (!(httpResponseData->state & HttpResponseData<SSL>::HTTP_END_CALLED)) {
/* Write mark, this propagates to WebSockets too */
writeMark();
/* WebSocket upgrades does not allow content-length */
if (allowContentLength) {
/* Even zero is a valid content-length */
Super::write("Content-Length: ", 16);
writeUnsigned(totalSize);
Super::write("\r\n\r\n", 4);
} else {
Super::write("\r\n", 2);
}
/* Mark end called */
httpResponseData->state |= HttpResponseData<SSL>::HTTP_END_CALLED;
}
/* Even if we supply no new data to write, its failed boolean is useful to know
* if it failed to drain any prior failed header writes */
/* Write as much as possible without causing backpressure */
auto [written, failed] = Super::write(data.data(), (int) data.length(), optional);
httpResponseData->offset += written;
/* Success is when we wrote the entire thing without any failures */
bool success = (unsigned int) written == data.length() && !failed;
/* If we are now at the end, start a timeout. Also start a timeout if we failed. */
if (!success || httpResponseData->offset == totalSize) {
Super::timeout(HTTP_TIMEOUT_S);
}
/* Remove onAborted function if we reach the end */
if (httpResponseData->offset == totalSize) {
markDone(httpResponseData);
}
return success;
}
}
/* This call is identical to end, but will never write content-length and is thus suitable for upgrades */
void upgrade() {
internalEnd({nullptr, 0}, 0, false, false);
}
public:
/* Immediately terminate this Http response */
using Super::close;
using Super::getRemoteAddress;
/* Note: Headers are not checked in regards to timeout.
* We only check when you actively push data or end the request */
/* Write the HTTP status */
HttpResponse *writeStatus(std::string_view status) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
/* Do not allow writing more than one status */
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_STATUS_CALLED) {
return this;
}
/* Update status */
httpResponseData->state |= HttpResponseData<SSL>::HTTP_STATUS_CALLED;
Super::write("HTTP/1.1 ", 9);
Super::write(status.data(), (int) status.length());
Super::write("\r\n", 2);
return this;
}
/* Write an HTTP header with string value */
HttpResponse *writeHeader(std::string_view key, std::string_view value) {
writeStatus(HTTP_200_OK);
Super::write(key.data(), (int) key.length());
Super::write(": ", 2);
Super::write(value.data(), (int) value.length());
Super::write("\r\n", 2);
return this;
}
/* Write an HTTP header with unsigned int value */
HttpResponse *writeHeader(std::string_view key, unsigned int value) {
Super::write(key.data(), (int) key.length());
Super::write(": ", 2);
writeUnsigned(value);
Super::write("\r\n", 2);
return this;
}
/* End the response with an optional data chunk. Always starts a timeout. */
void end(std::string_view data = {}) {
internalEnd(data, (int) data.length(), false);
}
/* Try and end the response. Returns [true, true] on success.
* Starts a timeout in some cases. Returns [ok, hasResponded] */
std::pair<bool, bool> tryEnd(std::string_view data, int totalSize = 0) {
return {internalEnd(data, totalSize, true), hasResponded()};
}
/* Write parts of the response in chunking fashion. Starts timeout if failed. */
bool write(std::string_view data) {
writeStatus(HTTP_200_OK);
/* Do not allow sending 0 chunks, they mark end of response */
if (!data.length()) {
/* If you called us, then according to you it was fine to call us so it's fine to still call us */
return true;
}
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
if (!(httpResponseData->state & HttpResponseData<SSL>::HTTP_WRITE_CALLED)) {
/* Write mark on first call to write */
writeMark();
writeHeader("Transfer-Encoding", "chunked");
httpResponseData->state |= HttpResponseData<SSL>::HTTP_WRITE_CALLED;
}
Super::write("\r\n", 2);
writeUnsignedHex((unsigned int) data.length());
Super::write("\r\n", 2);
auto [written, failed] = Super::write(data.data(), (int) data.length());
if (failed) {
Super::timeout(HTTP_TIMEOUT_S);
}
/* If we did not fail the write, accept more */
return !failed;
}
/* Get the current byte write offset for this Http response */
int getWriteOffset() {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
return httpResponseData->offset;
}
/* Checking if we have fully responded and are ready for another request */
bool hasResponded() {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
return !(httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING);
}
/* Corks the response if possible. Leaves already corked socket be. */
HttpResponse *cork(fu2::unique_function<void()> &&handler) {
if (!Super::isCorked() && Super::canCork()) {
Super::cork();
handler();
/* Timeout on uncork failure, since most writes will succeed while corked */
auto [written, failed] = Super::uncork();
if (failed) {
/* For now we only have one single timeout so let's use it */
/* This behavior should equal the behavior in HttpContext when uncorking fails */
Super::timeout(HTTP_TIMEOUT_S);
}
} else {
/* We are already corked, or can't cork so let's just call the handler */
handler();
}
return this;
}
/* Attach handler for writable HTTP response */
HttpResponse *onWritable(fu2::unique_function<bool(int)> &&handler) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onWritable = std::move(handler);
return this;
}
/* Attach handler for aborted HTTP request */
HttpResponse *onAborted(fu2::unique_function<void()> &&handler) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onAborted = std::move(handler);
return this;
}
/* Attach a read handler for data sent. Will be called with FIN set true if last segment. */
void onData(fu2::unique_function<void(std::string_view, bool)> &&handler) {
HttpResponseData<SSL> *data = getHttpResponseData();
data->inStream = std::move(handler);
}
};
}
#endif // UWS_HTTPRESPONSE_H

View File

@ -0,0 +1,57 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPRESPONSEDATA_H
#define UWS_HTTPRESPONSEDATA_H
/* This data belongs to the HttpResponse */
#include "HttpParser.h"
#include "AsyncSocketData.h"
#include "f2/function2.hpp"
namespace uWS {
template <bool SSL>
struct HttpResponseData : AsyncSocketData<SSL>, HttpParser {
template <bool> friend struct HttpResponse;
template <bool> friend struct HttpContext;
private:
/* Bits of status */
enum {
HTTP_STATUS_CALLED = 1, // used
HTTP_WRITE_CALLED = 2, // used
HTTP_END_CALLED = 4, // used
HTTP_RESPONSE_PENDING = 8, // used
HTTP_ENDED_STREAM_OUT = 16 // not used
};
/* Per socket event handlers */
fu2::unique_function<bool(int)> onWritable;
fu2::unique_function<void()> onAborted;
fu2::unique_function<void(std::string_view, bool)> inStream; // onData
/* Outgoing offset */
int offset = 0;
/* Current state (content-length sent, status sent, write called, etc */
int state = 0;
};
}
#endif // UWS_HTTPRESPONSEDATA_H

View File

@ -0,0 +1,233 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_HTTPROUTER_HPP
#define UWS_HTTPROUTER_HPP
#include <map>
#include <vector>
#include <cstring>
#include <string_view>
#include <string>
#include <algorithm>
#include <memory>
#include "f2/function2.hpp"
namespace uWS {
template <class USERDATA>
struct HttpRouter {
/* These are public for now */
std::vector<std::string> methods = {"get", "post", "head", "put", "delete", "connect", "options", "trace", "patch"};
static const uint32_t HIGH_PRIORITY = 0xd0000000, MEDIUM_PRIORITY = 0xe0000000, LOW_PRIORITY = 0xf0000000;
private:
USERDATA userData;
static const unsigned int MAX_URL_SEGMENTS = 100;
/* Handler ids are 32-bit */
static const uint32_t HANDLER_MASK = 0x0fffffff;
/* Methods and their respective priority */
std::map<std::string, int> priority;
/* List of handlers */
std::vector<fu2::unique_function<bool(HttpRouter *)>> handlers;
/* Current URL cache */
std::string_view currentUrl;
std::string_view urlSegmentVector[MAX_URL_SEGMENTS];
int urlSegmentTop;
/* The matching tree */
struct Node {
std::string name;
std::vector<std::unique_ptr<Node>> children;
std::vector<uint32_t> handlers;
} root = {"rootNode"};
/* Advance from parent to child, adding child if necessary */
Node *getNode(Node *parent, std::string child) {
for (std::unique_ptr<Node> &node : parent->children) {
if (node->name == child) {
return node.get();
}
}
/* Insert sorted, but keep order if parent is root (we sort methods by priority elsewhere) */
std::unique_ptr<Node> newNode(new Node({child}));
return parent->children.emplace(std::upper_bound(parent->children.begin(), parent->children.end(), newNode, [parent, this](auto &a, auto &b) {
return b->name.length() && (parent != &root) && (b->name < a->name);
}), std::move(newNode))->get();
}
/* Basically a pre-allocated stack */
struct RouteParameters {
friend struct HttpRouter;
private:
std::string_view params[MAX_URL_SEGMENTS];
int paramsTop;
void reset() {
paramsTop = -1;
}
void push(std::string_view param) {
/* We check these bounds indirectly via the urlSegments limit */
params[++paramsTop] = param;
}
void pop() {
/* Same here, we cannot pop outside */
paramsTop--;
}
} routeParameters;
/* Set URL for router. Will reset any URL cache */
inline void setUrl(std::string_view url) {
/* Remove / from input URL */
currentUrl = url.substr(std::min<unsigned int>((unsigned int) url.length(), 1));
urlSegmentTop = -1;
}
/* Lazily parse or read from cache */
inline std::string_view getUrlSegment(int urlSegment) {
if (urlSegment > urlSegmentTop) {
/* Return empty segment if we are out of URL or stack space, but never for first url segment */
if (!currentUrl.length() || urlSegment > 99) {
return {};
}
auto segmentLength = currentUrl.find('/');
if (segmentLength == std::string::npos) {
segmentLength = currentUrl.length();
/* Push to url segment vector */
urlSegmentVector[urlSegment] = currentUrl.substr(0, segmentLength);
urlSegmentTop++;
/* Update currentUrl */
currentUrl = currentUrl.substr(segmentLength);
} else {
/* Push to url segment vector */
urlSegmentVector[urlSegment] = currentUrl.substr(0, segmentLength);
urlSegmentTop++;
/* Update currentUrl */
currentUrl = currentUrl.substr(segmentLength + 1);
}
}
/* In any case we return it */
return urlSegmentVector[urlSegment];
}
/* Executes as many handlers it can */
bool executeHandlers(Node *parent, int urlSegment, USERDATA &userData) {
/* If we have no more URL and not on first round, return where we may stand */
if (urlSegment && !getUrlSegment(urlSegment).length()) {
/* We have reached accross the entire URL with no stoppage, execute */
for (int handler : parent->handlers) {
if (handlers[handler & HANDLER_MASK](this)) {
return true;
}
}
/* We reached the end, so go back */
return false;
}
for (auto &p : parent->children) {
if (p->name.length() && p->name[0] == '*') {
/* Wildcard match (can be seen as a shortcut) */
for (int handler : p->handlers) {
if (handlers[handler & HANDLER_MASK](this)) {
return true;
}
}
} else if (p->name.length() && p->name[0] == ':' && getUrlSegment(urlSegment).length()) {
/* Parameter match */
routeParameters.push(getUrlSegment(urlSegment));
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
return true;
}
routeParameters.pop();
} else if (p->name == getUrlSegment(urlSegment)) {
/* Static match */
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
return true;
}
}
}
return false;
}
public:
HttpRouter() {
int p = 0;
for (std::string &method : methods) {
priority[method] = p++;
}
}
std::pair<int, std::string_view *> getParameters() {
return {routeParameters.paramsTop, routeParameters.params};
}
USERDATA &getUserData() {
return userData;
}
/* Fast path */
bool route(std::string_view method, std::string_view url) {
/* Reset url parsing cache */
setUrl(url);
routeParameters.reset();
/* Begin by finding the method node */
for (auto &p : root.children) {
if (p->name == method) {
/* Then route the url */
return executeHandlers(p.get(), 0, userData);
}
}
/* We did not find any handler for this method and url */
return false;
}
/* Adds the corresponding entires in matching tree and handler list */
void add(std::vector<std::string> methods, std::string pattern, fu2::unique_function<bool(HttpRouter *)> &&handler, int priority = MEDIUM_PRIORITY) {
for (std::string method : methods) {
/* Lookup method */
Node *node = getNode(&root, method);
/* Iterate over all segments */
setUrl(pattern);
for (int i = 0; getUrlSegment(i).length() || i == 0; i++) {
node = getNode(node, std::string(getUrlSegment(i)));
}
/* Insert handler in order sorted by priority (most significant 1 byte) */
node->handlers.insert(std::upper_bound(node->handlers.begin(), node->handlers.end(), (uint32_t) (priority | handlers.size())), (uint32_t) (priority | handlers.size()));
}
/* Alloate this handler */
handlers.emplace_back(std::move(handler));
}
};
}
#endif // UWS_HTTPROUTER_HPP

View File

@ -0,0 +1,169 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_LOOP_H
#define UWS_LOOP_H
/* The loop is lazily created per-thread and run with uWS::run() */
#include "LoopData.h"
#include <libusockets.h>
namespace uWS {
struct Loop {
private:
static void wakeupCb(us_loop_t *loop) {
LoopData *loopData = (LoopData *) us_loop_ext(loop);
/* Swap current deferQueue */
loopData->deferMutex.lock();
int oldDeferQueue = loopData->currentDeferQueue;
loopData->currentDeferQueue = (loopData->currentDeferQueue + 1) % 2;
loopData->deferMutex.unlock();
/* Drain the queue */
for (auto &x : loopData->deferQueues[oldDeferQueue]) {
x();
}
loopData->deferQueues[oldDeferQueue].clear();
}
static void preCb(us_loop_t *loop) {
LoopData *loopData = (LoopData *) us_loop_ext(loop);
for (auto &p : loopData->preHandlers) {
p.second((Loop *) loop);
}
}
static void postCb(us_loop_t *loop) {
LoopData *loopData = (LoopData *) us_loop_ext(loop);
for (auto &p : loopData->postHandlers) {
p.second((Loop *) loop);
}
}
Loop() = delete;
~Loop() = default;
Loop *init() {
new (us_loop_ext((us_loop_t *) this)) LoopData;
return this;
}
static Loop *create(void *hint) {
return ((Loop *) us_create_loop(hint, wakeupCb, preCb, postCb, sizeof(LoopData)))->init();
}
/* What to do with loops created with existingNativeLoop? */
struct LoopCleaner {
~LoopCleaner() {
if(loop && cleanMe) {
loop->free();
}
}
Loop *loop = nullptr;
bool cleanMe = false;
};
public:
/* Lazily initializes a per-thread loop and returns it.
* Will automatically free all initialized loops at exit. */
static Loop *get(void *existingNativeLoop = nullptr) {
static thread_local LoopCleaner lazyLoop;
if (!lazyLoop.loop) {
/* If we are given a native loop pointer we pass that to uSockets and let it deal with it */
if (existingNativeLoop) {
/* Todo: here we want to pass the pointer, not a boolean */
lazyLoop.loop = create(existingNativeLoop);
/* We cannot register automatic free here, must be manually done */
} else {
lazyLoop.loop = create(nullptr);
lazyLoop.cleanMe = true;
}
}
return lazyLoop.loop;
}
/* Freeing the default loop should be done once */
void free() {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->~LoopData();
/* uSockets will track whether this loop is owned by us or a borrowed alien loop */
us_loop_free((us_loop_t *) this);
}
void addPostHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->postHandlers.emplace(key, std::move(handler));
}
/* Bug: what if you remove a handler while iterating them? */
void removePostHandler(void *key) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->postHandlers.erase(key);
}
void addPreHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->preHandlers.emplace(key, std::move(handler));
}
/* Bug: what if you remove a handler while iterating them? */
void removePreHandler(void *key) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->preHandlers.erase(key);
}
/* Defer this callback on Loop's thread of execution */
void defer(fu2::unique_function<void()> &&cb) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
//if (std::thread::get_id() == ) // todo: add fast path for same thread id
loopData->deferMutex.lock();
loopData->deferQueues[loopData->currentDeferQueue].emplace_back(std::move(cb));
loopData->deferMutex.unlock();
us_wakeup_loop((us_loop_t *) this);
}
/* Actively block and run this loop */
void run() {
us_loop_run((us_loop_t *) this);
}
/* Passively integrate with the underlying default loop */
/* Used to seamlessly integrate with third parties such as Node.js */
void integrate() {
us_loop_integrate((us_loop_t *) this);
}
};
/* Can be called from any thread to run the thread local loop */
inline void run() {
Loop::get()->run();
}
}
#endif // UWS_LOOP_H

View File

@ -0,0 +1,72 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_LOOPDATA_H
#define UWS_LOOPDATA_H
#include <thread>
#include <functional>
#include <vector>
#include <mutex>
#include <map>
#include "PerMessageDeflate.h"
#include "f2/function2.hpp"
namespace uWS {
struct Loop;
struct alignas(16) LoopData {
friend struct Loop;
private:
std::mutex deferMutex;
int currentDeferQueue = 0;
std::vector<fu2::unique_function<void()>> deferQueues[2];
/* Map from void ptr to handler */
std::map<void *, fu2::unique_function<void(Loop *)>> postHandlers, preHandlers;
public:
~LoopData() {
/* If we have had App.ws called with compression we need to clear this */
if (zlibContext) {
delete zlibContext;
delete inflationStream;
delete deflationStream;
}
delete [] corkBuffer;
}
/* Good 16k for SSL perf. */
static const int CORK_BUFFER_SIZE = 16 * 1024;
/* Cork data */
char *corkBuffer = new char[CORK_BUFFER_SIZE];
int corkOffset = 0;
void *corkedSocket = nullptr;
/* Per message deflate data */
ZlibContext *zlibContext = nullptr;
InflationStream *inflationStream = nullptr;
DeflationStream *deflationStream = nullptr;
};
}
#endif // UWS_LOOPDATA_H

View File

@ -0,0 +1,187 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This standalone module implements deflate / inflate streams */
#ifndef UWS_PERMESSAGEDEFLATE_H
#define UWS_PERMESSAGEDEFLATE_H
#ifndef UWS_NO_ZLIB
#include <zlib.h>
#endif
#include <string>
namespace uWS {
/* Do not compile this module if we don't want it */
#ifdef UWS_NO_ZLIB
struct ZlibContext {};
struct InflationStream {
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
return compressed;
}
};
struct DeflationStream {
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
return raw;
}
};
#else
#define LARGE_BUFFER_SIZE 1024 * 16 // todo: fix this
struct ZlibContext {
/* Any returned data is valid until next same-class call.
* We need to have two classes to allow inflation followed
* by many deflations without modifying the inflation */
std::string dynamicDeflationBuffer;
std::string dynamicInflationBuffer;
char *deflationBuffer;
char *inflationBuffer;
ZlibContext() {
deflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
inflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
}
~ZlibContext() {
free(deflationBuffer);
free(inflationBuffer);
}
};
struct DeflationStream {
z_stream deflationStream = {};
DeflationStream() {
deflateInit2(&deflationStream, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
}
/* Deflate and optionally reset */
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
/* Odd place to clear this one, fix */
zlibContext->dynamicDeflationBuffer.clear();
deflationStream.next_in = (Bytef *) raw.data();
deflationStream.avail_in = (unsigned int) raw.length();
/* This buffer size has to be at least 6 bytes for Z_SYNC_FLUSH to work */
const int DEFLATE_OUTPUT_CHUNK = LARGE_BUFFER_SIZE;
int err;
do {
deflationStream.next_out = (Bytef *) zlibContext->deflationBuffer;
deflationStream.avail_out = DEFLATE_OUTPUT_CHUNK;
err = ::deflate(&deflationStream, Z_SYNC_FLUSH);
if (Z_OK == err && deflationStream.avail_out == 0) {
zlibContext->dynamicDeflationBuffer.append(zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
continue;
} else {
break;
}
} while (true);
/* This must not change avail_out */
if (reset) {
deflateReset(&deflationStream);
}
if (zlibContext->dynamicDeflationBuffer.length()) {
zlibContext->dynamicDeflationBuffer.append(zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
return {(char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4};
}
return {
zlibContext->deflationBuffer,
DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out - 4
};
}
~DeflationStream() {
deflateEnd(&deflationStream);
}
};
struct InflationStream {
z_stream inflationStream = {};
InflationStream() {
inflateInit2(&inflationStream, -15);
}
~InflationStream() {
inflateEnd(&inflationStream);
}
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
/* We clear this one here, could be done better */
zlibContext->dynamicInflationBuffer.clear();
inflationStream.next_in = (Bytef *) compressed.data();
inflationStream.avail_in = (unsigned int) compressed.length();
int err;
do {
inflationStream.next_out = (Bytef *) zlibContext->inflationBuffer;
inflationStream.avail_out = LARGE_BUFFER_SIZE;
err = ::inflate(&inflationStream, Z_SYNC_FLUSH);
if (err == Z_OK && inflationStream.avail_out) {
break;
}
zlibContext->dynamicInflationBuffer.append(zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
} while (inflationStream.avail_out == 0 && zlibContext->dynamicInflationBuffer.length() <= maxPayloadLength);
inflateReset(&inflationStream);
if ((err != Z_BUF_ERROR && err != Z_OK) || zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
return {nullptr, 0};
}
if (zlibContext->dynamicInflationBuffer.length()) {
zlibContext->dynamicInflationBuffer.append(zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
/* Let's be strict about the max size */
if (zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
return {nullptr, 0};
}
return {zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length()};
}
/* Let's be strict about the max size */
if ((LARGE_BUFFER_SIZE - inflationStream.avail_out) > maxPayloadLength) {
return {nullptr, 0};
}
return {zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out};
}
};
#endif
}
#endif // UWS_PERMESSAGEDEFLATE_H

View File

@ -0,0 +1,429 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_TOPICTREE_H
#define UWS_TOPICTREE_H
#include <iostream>
#include <vector>
#include <map>
#include <string_view>
#include <functional>
#include <set>
#include <chrono>
#include <list>
namespace uWS {
/* A Subscriber is an extension of a socket */
struct Subscriber {
std::list<struct Topic *> subscriptions;
void *user;
Subscriber(void *user) : user(user) {}
};
struct Topic {
/* Memory for our name */
char *name;
size_t length;
/* Our parent or nullptr */
Topic *parent = nullptr;
/* Next triggered Topic */
bool triggered = false;
/* Exact string matches */
std::map<std::string_view, Topic *> children;
/* Wildcard child */
Topic *wildcardChild = nullptr;
/* Terminating wildcard child */
Topic *terminatingWildcardChild = nullptr;
/* What we published */
std::map<unsigned int, std::string> messages;
std::set<Subscriber *> subs;
};
struct TopicTree {
private:
std::function<int(Subscriber *, std::string_view)> cb;
Topic *root = new Topic;
/* Global messageId for deduplication of overlapping topics and ordering between topics */
unsigned int messageId = 0;
/* The triggered topics */
Topic *triggeredTopics[64];
int numTriggeredTopics = 0;
Subscriber *min = (Subscriber *) UINTPTR_MAX;
/* Cull or trim unused Topic nodes from leaf to root */
void trimTree(Topic *topic) {
if (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) {
Topic *parent = topic->parent;
if (topic->length == 1) {
if (topic->name[0] == '#') {
parent->terminatingWildcardChild = nullptr;
} else if (topic->name[0] == '+') {
parent->wildcardChild = nullptr;
}
}
/* Erase us from our parents set (wildcards also live here) */
parent->children.erase(std::string_view(topic->name, topic->length));
/* If this node is triggered, make sure to remove it from the triggered list */
if (topic->triggered) {
Topic *tmp[64];
int length = 0;
for (int i = 0; i < numTriggeredTopics; i++) {
if (triggeredTopics[i] != topic) {
tmp[length++] = triggeredTopics[i];
}
}
for (int i = 0; i < length; i++) {
triggeredTopics[i] = tmp[i];
}
numTriggeredTopics = length;
}
/* Free various memory for the node */
delete [] topic->name;
delete topic;
if (parent != root) {
trimTree(parent);
}
}
}
/* Should be getData and commit? */
void publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::string_view message) {
/* If we already have 64 triggered topics make sure to drain it here */
if (numTriggeredTopics == 64) {
drain();
}
for (; stop != std::string::npos; start = stop + 1) {
stop = topic.find('/', start);
std::string_view segment = topic.substr(start, stop - start);
/* Do we have a terminating wildcard child? */
if (iterator->terminatingWildcardChild) {
iterator->terminatingWildcardChild->messages[messageId] = message;
/* Add this topic to triggered */
if (!iterator->terminatingWildcardChild->triggered) {
triggeredTopics[numTriggeredTopics++] = iterator->terminatingWildcardChild;
/* Keep track of lowest subscriber */
if (*iterator->terminatingWildcardChild->subs.begin() < min) {
min = *iterator->terminatingWildcardChild->subs.begin();
}
iterator->terminatingWildcardChild->triggered = true;
}
}
/* Do we have a wildcard child? */
if (iterator->wildcardChild) {
publish(iterator->wildcardChild, stop + 1, stop, topic, message);
}
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
if (it == iterator->children.end()) {
/* Stop trying to match by exact string */
return;
}
iterator = it->second;
}
/* If we went all the way we matched exactly */
iterator->messages[messageId] = message;
/* Add this topic to triggered */
if (!iterator->triggered) {
triggeredTopics[numTriggeredTopics++] = iterator;
/* Keep track of lowest subscriber */
if (*iterator->subs.begin() < min) {
min = *iterator->subs.begin();
}
iterator->triggered = true;
}
}
public:
TopicTree(std::function<int(Subscriber *, std::string_view)> cb) {
this->cb = cb;
}
~TopicTree() {
delete root;
}
void subscribe(std::string_view topic, Subscriber *subscriber) {
/* Start iterating from the root */
Topic *iterator = root;
/* Traverse the topic, inserting a node for every new segment separated by / */
for (size_t start = 0, stop = 0; stop != std::string::npos; start = stop + 1) {
stop = topic.find('/', start);
std::string_view segment = topic.substr(start, stop - start);
auto lb = iterator->children.lower_bound(segment);
if (lb != iterator->children.end() && !(iterator->children.key_comp()(segment, lb->first))) {
iterator = lb->second;
} else {
/* Allocate and insert new node */
Topic *newTopic = new Topic;
newTopic->parent = iterator;
newTopic->name = new char[segment.length()];
newTopic->length = segment.length();
newTopic->terminatingWildcardChild = nullptr;
newTopic->wildcardChild = nullptr;
memcpy(newTopic->name, segment.data(), segment.length());
/* For simplicity we do insert wildcards with text */
iterator->children.insert(lb, {std::string_view(newTopic->name, segment.length()), newTopic});
/* Store fast lookup to wildcards */
if (segment.length() == 1) {
/* If this segment is '+' it is a wildcard */
if (segment[0] == '+') {
iterator->wildcardChild = newTopic;
}
/* If this segment is '#' it is a terminating wildcard */
if (segment[0] == '#') {
iterator->terminatingWildcardChild = newTopic;
}
}
iterator = newTopic;
}
}
/* Add socket to Topic's Set */
auto [it, inserted] = iterator->subs.insert(subscriber);
/* Add Topic to list of subscriptions only if we weren't already subscribed */
if (inserted) {
subscriber->subscriptions.push_back(iterator);
}
}
void publish(std::string_view topic, std::string_view message) {
publish(root, 0, 0, topic, message);
messageId++;
}
/* Returns whether we were subscribed prior */
bool unsubscribe(std::string_view topic, Subscriber *subscriber) {
/* Subscribers are likely to have very few subscriptions (20 or fewer) */
if (subscriber) {
/* Lookup exact Topic ptr from string */
Topic *iterator = root;
for (size_t start = 0, stop = 0; stop != std::string::npos; start = stop + 1) {
stop = topic.find('/', start);
std::string_view segment = topic.substr(start, stop - start);
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
if (it == iterator->children.end()) {
/* This topic does not even exist */
return false;
}
iterator = it->second;
}
/* Try and remove this topic from our list */
for (auto it = subscriber->subscriptions.begin(); it != subscriber->subscriptions.end(); it++) {
if (*it == iterator) {
/* Remove topic ptr from our list */
subscriber->subscriptions.erase(it);
/* Remove us from Topic's subs */
iterator->subs.erase(subscriber);
trimTree(iterator);
return true;
}
}
}
return false;
}
/* Can be called with nullptr, ignore it then */
void unsubscribeAll(Subscriber *subscriber) {
if (subscriber) {
for (Topic *topic : subscriber->subscriptions) {
topic->subs.erase(subscriber);
trimTree(topic);
}
subscriber->subscriptions.clear();
}
}
/* Drain the tree by emitting what to send with every Subscriber */
/* Better name would be commit() and making it public so that one can commit and shutdown, etc */
void drain() {
/* Do nothing if nothing to send */
if (!numTriggeredTopics) {
return;
}
/* bug fix: Filter triggered topics without subscribers */
int numFilteredTriggeredTopics = 0;
for (int i = 0; i < numTriggeredTopics; i++) {
if (triggeredTopics[i]->subs.size()) {
triggeredTopics[numFilteredTriggeredTopics++] = triggeredTopics[i];
}
}
numTriggeredTopics = numFilteredTriggeredTopics;
if (!numTriggeredTopics) {
return;
}
/* bug fix: update min, as the one tracked via subscribe gets invalid as you unsubscribe */
min = (Subscriber *)UINTPTR_MAX;
for (int i = 0; i < numTriggeredTopics; i++) {
if ((triggeredTopics[i]->subs.size()) && (min > *triggeredTopics[i]->subs.begin())) {
min = *triggeredTopics[i]->subs.begin();
}
}
/* Check if we really have any sockets still */
if (min != (Subscriber *)UINTPTR_MAX) {
/* Up to 64 triggered Topics per batch */
std::map<uint64_t, std::string> intersectionCache;
/* Loop over these here */
std::set<Subscriber *>::iterator it[64];
std::set<Subscriber *>::iterator end[64];
for (int i = 0; i < numTriggeredTopics; i++) {
it[i] = triggeredTopics[i]->subs.begin();
end[i] = triggeredTopics[i]->subs.end();
}
/* Empty all sets from unique subscribers */
for (int nonEmpty = numTriggeredTopics; nonEmpty; ) {
Subscriber *nextMin = (Subscriber *)UINTPTR_MAX;
/* The message sets relevant for this intersection */
std::map<unsigned int, std::string> *perSubscriberIntersectingTopicMessages[64];
int numPerSubscriberIntersectingTopicMessages = 0;
uint64_t intersection = 0;
for (int i = 0; i < numTriggeredTopics; i++) {
if ((it[i] != end[i]) && (*it[i] == min)) {
/* Mark this intersection */
intersection |= ((uint64_t)1 << i);
perSubscriberIntersectingTopicMessages[numPerSubscriberIntersectingTopicMessages++] = &triggeredTopics[i]->messages;
it[i]++;
if (it[i] == end[i]) {
nonEmpty--;
}
else {
if (nextMin > *it[i]) {
nextMin = *it[i];
}
}
}
else {
/* We need to lower nextMin to us, in the case of min being the last in a set */
if ((it[i] != end[i]) && (nextMin > *it[i])) {
nextMin = *it[i];
}
}
}
/* Generate cache for intersection */
if (intersectionCache[intersection].length() == 0) {
/* Build the union in order without duplicates */
std::map<unsigned int, std::string> complete;
for (int i = 0; i < numPerSubscriberIntersectingTopicMessages; i++) {
complete.insert(perSubscriberIntersectingTopicMessages[i]->begin(), perSubscriberIntersectingTopicMessages[i]->end());
}
/* Create the linear cache */
std::string res;
for (auto &p : complete) {
res.append(p.second);
}
cb(min, intersectionCache[intersection] = std::move(res));
}
else {
cb(min, intersectionCache[intersection]);
}
min = nextMin;
}
}
/* Clear messages of triggered Topics */
for (int i = 0; i < numTriggeredTopics; i++) {
triggeredTopics[i]->messages.clear();
triggeredTopics[i]->triggered = false;
}
numTriggeredTopics = 0;
}
void print(Topic *root = nullptr, int indentation = 1) {
if (root == nullptr) {
std::cout << "Print of tree:" << std::endl;
root = this->root;
}
for (auto p : root->children) {
for (int i = 0; i < indentation; i++) {
std::cout << " ";
}
std::cout << std::string_view(p.second->name, p.second->length) << " = " << p.second->messages.size() << " publishes, " << p.second->subs.size() << " subscribers {";
for (auto &p : p.second->subs) {
std::cout << p << " referring to socket: " << p->user << ", ";
}
std::cout << "}" << std::endl;
print(p.second, indentation + 1);
}
}
};
}
#endif

View File

@ -0,0 +1,66 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_UTILITIES_H
#define UWS_UTILITIES_H
/* Various common utilities */
#include <cstdint>
namespace uWS {
namespace utils {
inline int u32toaHex(uint32_t value, char *dst) {
char palette[] = "0123456789abcdef";
char temp[10];
char *p = temp;
do {
*p++ = palette[value % 16];
value /= 16;
} while (value > 0);
int ret = (int) (p - temp);
do {
*dst++ = *--p;
} while (p != temp);
return ret;
}
inline int u32toa(uint32_t value, char *dst) {
char temp[10];
char *p = temp;
do {
*p++ = (char) ((value % 10) + '0');
value /= 10;
} while (value > 0);
int ret = (int) (p - temp);
do {
*dst++ = *--p;
} while (p != temp);
return ret;
}
}
}
#endif // UWS_UTILITIES_H

View File

@ -0,0 +1,213 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKET_H
#define UWS_WEBSOCKET_H
#include "WebSocketData.h"
#include "WebSocketProtocol.h"
#include "AsyncSocket.h"
#include "WebSocketContextData.h"
#include <string_view>
namespace uWS {
template <bool SSL, bool isServer>
struct WebSocket : AsyncSocket<SSL> {
template <bool> friend struct TemplatedApp;
private:
typedef AsyncSocket<SSL> Super;
void *init(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) {
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, slidingCompression, std::move(backpressure));
return this;
}
public:
/* Returns pointer to the per socket user data */
void *getUserData() {
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
/* We just have it overallocated by sizeof type */
return (webSocketData + 1);
}
/* See AsyncSocket */
using Super::getBufferedAmount;
using Super::getRemoteAddress;
/* Simple, immediate close of the socket. Emits close event */
using Super::close;
/* Send or buffer a WebSocket frame, compressed or not. Returns false on increased user space backpressure. */
bool send(std::string_view message, uWS::OpCode opCode = uWS::OpCode::BINARY, bool compress = false) {
/* Transform the message to compressed domain if requested */
if (compress) {
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
/* Check and correct the compress hint */
if (opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
LoopData *loopData = Super::getLoopData();
/* Compress using either shared or dedicated deflationStream */
if (webSocketData->deflationStream) {
message = webSocketData->deflationStream->deflate(loopData->zlibContext, message, false);
} else {
message = loopData->deflationStream->deflate(loopData->zlibContext, message, true);
}
} else {
compress = false;
}
}
/* Check to see if we can cork for the user */
bool automaticallyCorked = false;
if (!Super::isCorked() && Super::canCork()) {
automaticallyCorked = true;
Super::cork();
}
/* Get size, alloate size, write if needed */
size_t messageFrameSize = protocol::messageFrameSize(message.length());
auto[sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize);
protocol::formatMessage<isServer>(sendBuffer, message.data(), message.length(), opCode, message.length(), compress);
/* This is the slow path, when we couldn't cork for the user */
if (requiresWrite) {
auto[written, failed] = Super::write(sendBuffer, (int) messageFrameSize);
/* For now, we are slow here */
free(sendBuffer);
if (failed) {
/* Return false for failure, skipping to reset the timeout below */
return false;
}
}
/* Uncork here if we automatically corked for the user */
if (automaticallyCorked) {
auto [written, failed] = Super::uncork();
if (failed) {
return false;
}
}
/* Every successful send resets the timeout */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
AsyncSocket<SSL>::timeout(webSocketContextData->idleTimeout);
/* Return success */
return true;
}
/* Send websocket close frame, emit close event, send FIN if successful */
void end(int code, std::string_view message = {}) {
/* Check if we already called this one */
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (webSocketData->isShuttingDown) {
return;
}
/* We postpone any FIN sending to either drainage or uncorking */
webSocketData->isShuttingDown = true;
/* Format and send the close frame */
static const int MAX_CLOSE_PAYLOAD = 123;
int length = (int) std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
char closePayload[MAX_CLOSE_PAYLOAD + 2];
int closePayloadLength = (int) protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
bool ok = send(std::string_view(closePayload, closePayloadLength), OpCode::CLOSE);
/* FIN if we are ok and not corked */
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) this;
if (!webSocket->isCorked()) {
if (ok) {
/* If we are not corked, and we just sent off everything, we need to FIN right here.
* In all other cases, we need to fin either if uncork was successful, or when drainage is complete. */
webSocket->shutdown();
}
}
/* Emit close event */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
if (webSocketContextData->closeHandler) {
webSocketContextData->closeHandler(this, code, message);
}
/* Make sure to unsubscribe from any pub/sub node at exit */
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
delete webSocketData->subscriber;
webSocketData->subscriber = nullptr;
}
/* Corks the response if possible. Leaves already corked socket be. */
void cork(fu2::unique_function<void()> &&handler) {
if (!Super::isCorked() && Super::canCork()) {
Super::cork();
handler();
/* There is no timeout when failing to uncork for WebSockets,
* as that is handled by idleTimeout */
auto [written, failed] = Super::uncork();
} else {
/* We are already corked, or can't cork so let's just call the handler */
handler();
}
}
/* Subscribe to a topic according to MQTT rules and syntax */
void subscribe(std::string_view topic) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
/* Make us a subscriber if we aren't yet */
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (!webSocketData->subscriber) {
webSocketData->subscriber = new Subscriber(this);
}
webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber);
}
/* Unsubscribe from a topic, returns true if we were subscribed */
bool unsubscribe(std::string_view topic) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber);
}
/* Publish a message to a topic according to MQTT rules and syntax */
void publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
/* Is the same as publishing per websocket context */
webSocketContextData->publish(topic, message, opCode, compress);
}
};
}
#endif // UWS_WEBSOCKET_H

View File

@ -0,0 +1,380 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETCONTEXT_H
#define UWS_WEBSOCKETCONTEXT_H
#include "WebSocketContextData.h"
#include "WebSocketProtocol.h"
#include "WebSocketData.h"
#include "WebSocket.h"
namespace uWS {
template <bool SSL, bool isServer>
struct WebSocketContext {
template <bool> friend struct TemplatedApp;
template <bool, typename> friend struct WebSocketProtocol;
private:
WebSocketContext() = delete;
us_socket_context_t *getSocketContext() {
return (us_socket_context_t *) this;
}
WebSocketContextData<SSL> *getExt() {
return (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
}
/* If we have negotiated compression, set this frame compressed */
static bool setCompressed(uWS::WebSocketState<isServer> *wState, void *s) {
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::ENABLED) {
webSocketData->compressionStatus = WebSocketData::CompressionStatus::COMPRESSED_FRAME;
return true;
} else {
return false;
}
}
static void forceClose(uWS::WebSocketState<isServer> *wState, void *s) {
us_socket_close(SSL, (us_socket_t *) s);
}
/* Returns true on breakage */
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, uWS::WebSocketState<isServer> *webSocketState, void *s) {
/* WebSocketData and WebSocketContextData */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
/* Is this a non-control frame? */
if (opCode < 3) {
/* Did we get everything in one go? */
if (!remainingBytes && fin && !webSocketData->fragmentBuffer.length()) {
/* Handle compressed frame */
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) s)));
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.length()) {
forceClose(webSocketState, s);
return true;
} else {
data = (char *) inflatedFrame.data();
length = inflatedFrame.length();
}
}
/* Check text messages for Utf-8 validity */
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
forceClose(webSocketState, s);
return true;
}
/* Emit message event & break if we are closed or shut down when returning */
if (webSocketContextData->messageHandler) {
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
}
} else {
/* Allocate fragment buffer up front first time */
if (!webSocketData->fragmentBuffer.length()) {
webSocketData->fragmentBuffer.reserve(length + remainingBytes);
}
/* Fragments forming a big message are not caught until appending them */
if (refusePayloadLength(length + webSocketData->fragmentBuffer.length(), webSocketState, s)) {
forceClose(webSocketState, s);
return true;
}
webSocketData->fragmentBuffer.append(data, length);
/* Are we done now? */
// todo: what if we don't have any remaining bytes yet we are not fin? forceclose!
if (!remainingBytes && fin) {
/* Handle compression */
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
// what's really the story here?
webSocketData->fragmentBuffer.append("....");
LoopData *loopData = (LoopData *) us_loop_ext(
us_socket_context_loop(SSL,
us_socket_context(SSL, (us_socket_t *) s)
)
);
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 4}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.length()) {
forceClose(webSocketState, s);
return true;
} else {
data = (char *) inflatedFrame.data();
length = inflatedFrame.length();
}
} else {
// reset length and data ptrs
length = webSocketData->fragmentBuffer.length();
data = webSocketData->fragmentBuffer.data();
}
/* Check text messages for Utf-8 validity */
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
forceClose(webSocketState, s);
return true;
}
/* Emit message and check for shutdown or close */
if (webSocketContextData->messageHandler) {
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
}
/* If we shutdown or closed, this will be taken care of elsewhere */
webSocketData->fragmentBuffer.clear();
}
}
} else {
/* Control frames need the websocket to send pings, pongs and close */
WebSocket<SSL, isServer> *webSocket = (WebSocket<SSL, isServer> *) s;
if (!remainingBytes && fin && !webSocketData->controlTipLength) {
if (opCode == CLOSE) {
auto closeFrame = protocol::parseClosePayload(data, length);
webSocket->end(closeFrame.code, std::string_view(closeFrame.message, closeFrame.length));
return true;
} else {
if (opCode == PING) {
webSocket->send(std::string_view(data, length), (OpCode) OpCode::PONG);
/*group->pingHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}*/
} else if (opCode == PONG) {
/*group->pongHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}*/
}
}
} else {
/* Here we never mind any size optimizations as we are in the worst possible path */
webSocketData->fragmentBuffer.append(data, length);
webSocketData->controlTipLength += (int) length;
if (!remainingBytes && fin) {
char *controlBuffer = (char *) webSocketData->fragmentBuffer.data() + webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength;
if (opCode == CLOSE) {
protocol::CloseFrame closeFrame = protocol::parseClosePayload(controlBuffer, webSocketData->controlTipLength);
webSocket->end(closeFrame.code, std::string_view(closeFrame.message, closeFrame.length));
return true;
} else {
if (opCode == PING) {
webSocket->send(std::string_view(controlBuffer, webSocketData->controlTipLength), (OpCode) OpCode::PONG);
/*group->pingHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}*/
} else if (opCode == PONG) {
/*group->pongHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}*/
}
}
/* Same here, we do not care for any particular smart allocation scheme */
webSocketData->fragmentBuffer.resize(webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength);
webSocketData->controlTipLength = 0;
}
}
}
return false;
}
static bool refusePayloadLength(uint64_t length, uWS::WebSocketState<isServer> *wState, void *s) {
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
/* Return true for refuse, false for accept */
return webSocketContextData->maxPayloadLength < length;
}
WebSocketContext<SSL, isServer> *init() {
/* Adopting a socket does not trigger open event.
* We arreive as WebSocket with timeout set and
* any backpressure from HTTP state kept. */
/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s) {
/* For whatever reason, if we already have emitted close event, do not emit it again */
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
if (!webSocketData->isShuttingDown) {
/* Emit close event */
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
if (webSocketContextData->closeHandler) {
webSocketContextData->closeHandler((WebSocket<SSL, true> *) s, 1006, {});
}
/* Make sure to unsubscribe from any pub/sub node at exit */
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
delete webSocketData->subscriber;
webSocketData->subscriber = nullptr;
}
/* Destruct in-placed data struct */
webSocketData->~WebSocketData();
return s;
});
/* Handle WebSocket data streams */
us_socket_context_on_data(SSL, getSocketContext(), [](auto *s, char *data, int length) {
/* We need the websocket data */
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
/* When in websocket shutdown mode, we do not care for ANY message, whether responding close frame or not.
* We only care for the TCP FIN really, not emitting any message after closing is key */
if (webSocketData->isShuttingDown) {
return s;
}
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
auto *asyncSocket = (AsyncSocket<SSL> *) s;
/* Every time we get data and not in shutdown state we simply reset the timeout */
asyncSocket->timeout(webSocketContextData->idleTimeout);
/* We always cork on data */
asyncSocket->cork();
/* This parser has virtually no overhead */
uWS::WebSocketProtocol<isServer, WebSocketContext<SSL, isServer>>::consume(data, length, (WebSocketState<isServer> *) webSocketData, s);
/* Uncorking a closed socekt is fine, in fact it is needed */
asyncSocket->uncork();
/* If uncorking was successful and we are in shutdown state then send TCP FIN */
if (asyncSocket->getBufferedAmount() == 0) {
/* We can now be in shutdown state */
if (webSocketData->isShuttingDown) {
/* Shutting down a closed socket is handled by uSockets and just fine */
asyncSocket->shutdown();
}
}
return s;
});
/* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */
us_socket_context_on_writable(SSL, getSocketContext(), [](auto *s) {
/* It makes sense to check for us_is_shut_down here and return if so, to avoid shutting down twice */
if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
return s;
}
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
WebSocketData *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
/* We store old backpressure since it is unclear whether write drained anything */
int backpressure = asyncSocket->getBufferedAmount();
/* Drain as much as possible */
asyncSocket->write(nullptr, 0);
/* Behavior: if we actively drain backpressure, always reset timeout (even if we are in shutdown) */
if (backpressure < asyncSocket->getBufferedAmount()) {
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
asyncSocket->timeout(webSocketContextData->idleTimeout);
}
/* Are we in (WebSocket) shutdown mode? */
if (webSocketData->isShuttingDown) {
/* Check if we just now drained completely */
if (asyncSocket->getBufferedAmount() == 0) {
/* Now perform the actual TCP/TLS shutdown which was postponed due to backpressure */
asyncSocket->shutdown();
}
} else if (backpressure > asyncSocket->getBufferedAmount()) {
/* Only call drain if we actually drained backpressure */
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
if (webSocketContextData->drainHandler) {
webSocketContextData->drainHandler((WebSocket<SSL, isServer> *) s);
}
/* No need to check for closed here as we leave the handler immediately*/
}
return s;
});
/* Handle FIN, HTTP does not support half-closed sockets, so simply close */
us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) {
/* If we get a fin, we just close I guess */
us_socket_close(SSL, (us_socket_t *) s);
return s;
});
/* Handle socket timeouts, simply close them so to not confuse client with FIN */
us_socket_context_on_timeout(SSL, getSocketContext(), [](auto *s) {
/* Timeout is very simple; we just close it */
us_socket_close(SSL, (us_socket_t *) s);
return s;
});
return this;
}
void free() {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
webSocketContextData->~WebSocketContextData();
us_socket_context_free(SSL, (us_socket_context_t *) this);
}
public:
/* WebSocket contexts are always child contexts to a HTTP context so no SSL options are needed as they are inherited */
static WebSocketContext *create(Loop *loop, us_socket_context_t *parentSocketContext) {
WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL>));
if (!webSocketContext) {
return nullptr;
}
/* Init socket context data */
new ((WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL>;
return webSocketContext->init();
}
};
}
#endif // UWS_WEBSOCKETCONTEXT_H

View File

@ -0,0 +1,100 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETCONTEXTDATA_H
#define UWS_WEBSOCKETCONTEXTDATA_H
#include "f2/function2.hpp"
#include <string_view>
#include "WebSocketProtocol.h"
#include "TopicTreeDraft.h"
namespace uWS {
template <bool, bool> struct WebSocket;
/* todo: this looks identical to WebSocketBehavior, why not just std::move that entire thing in? */
template <bool SSL>
struct WebSocketContextData {
/* The callbacks for this context */
fu2::unique_function<void(WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> messageHandler = nullptr;
fu2::unique_function<void(WebSocket<SSL, true> *)> drainHandler = nullptr;
fu2::unique_function<void(WebSocket<SSL, true> *, int, std::string_view)> closeHandler = nullptr;
/* Settings for this context */
size_t maxPayloadLength = 0;
int idleTimeout = 0;
/* There needs to be a maxBackpressure which will force close everything over that limit */
size_t maxBackpressure = 0;
/* Each websocket context has a topic tree for pub/sub */
TopicTree topicTree;
~WebSocketContextData() {
/* We must unregister any loop post handler here */
Loop::get()->removePostHandler(this);
Loop::get()->removePreHandler(this);
}
WebSocketContextData() : topicTree([this](Subscriber *s, std::string_view data) -> int {
/* We rely on writing to regular asyncSockets */
auto *asyncSocket = (AsyncSocket<SSL> *) s->user;
auto [written, failed] = asyncSocket->write(data.data(), (int) data.length());
if (!failed) {
asyncSocket->timeout(this->idleTimeout);
} else {
/* Note: this assumes we are not corked, as corking will swallow things and fail later on */
/* Check if we now have too much backpressure (todo: don't buffer up before check) */
if ((unsigned int) asyncSocket->getBufferedAmount() > maxBackpressure) {
asyncSocket->close();
}
}
/* Reserved, unused */
return 0;
}) {
/* We empty for both pre and post just to make sure */
Loop::get()->addPostHandler(this, [this](Loop *loop) {
/* Commit pub/sub batches every loop iteration */
topicTree.drain();
});
Loop::get()->addPreHandler(this, [this](Loop *loop) {
/* Commit pub/sub batches every loop iteration */
topicTree.drain();
});
}
/* Helper for topictree publish, common path from app and ws */
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress) {
/* We frame the message right here and only pass raw bytes to the pub/subber */
char *dst = (char *) malloc(protocol::messageFrameSize(message.size()));
size_t dst_length = protocol::formatMessage<true>(dst, message.data(), message.length(), opCode, message.length(), false);
topicTree.publish(topic, std::string_view(dst, dst_length));
::free(dst);
}
};
}
#endif // UWS_WEBSOCKETCONTEXTDATA_H

View File

@ -0,0 +1,70 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETDATA_H
#define UWS_WEBSOCKETDATA_H
#include "WebSocketProtocol.h"
#include "AsyncSocketData.h"
#include "PerMessageDeflate.h"
#include <string>
namespace uWS {
struct WebSocketData : AsyncSocketData<false>, WebSocketState<true> {
template <bool, bool> friend struct WebSocketContext;
template <bool, bool> friend struct WebSocket;
private:
std::string fragmentBuffer;
int controlTipLength = 0;
bool isShuttingDown = 0;
enum CompressionStatus : char {
DISABLED,
ENABLED,
COMPRESSED_FRAME
} compressionStatus;
/* We might have a dedicated compressor */
DeflationStream *deflationStream = nullptr;
/* We could be a subscriber */
Subscriber *subscriber = nullptr;
public:
WebSocketData(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
/* Initialize the dedicated sliding window */
if (perMessageDeflate && slidingCompression) {
deflationStream = new DeflationStream;
}
}
~WebSocketData() {
if (deflationStream) {
delete deflationStream;
}
if (subscriber) {
delete subscriber;
}
}
};
}
#endif // UWS_WEBSOCKETDATA_H

View File

@ -0,0 +1,169 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETEXTENSIONS_H
#define UWS_WEBSOCKETEXTENSIONS_H
#include <climits>
#include <string_view>
namespace uWS {
enum Options : unsigned int {
NO_OPTIONS = 0,
PERMESSAGE_DEFLATE = 1,
SERVER_NO_CONTEXT_TAKEOVER = 2, // remove this
CLIENT_NO_CONTEXT_TAKEOVER = 4, // remove this
NO_DELAY = 8,
SLIDING_DEFLATE_WINDOW = 16
};
enum ExtensionTokens {
TOK_PERMESSAGE_DEFLATE = 1838,
TOK_SERVER_NO_CONTEXT_TAKEOVER = 2807,
TOK_CLIENT_NO_CONTEXT_TAKEOVER = 2783,
TOK_SERVER_MAX_WINDOW_BITS = 2372,
TOK_CLIENT_MAX_WINDOW_BITS = 2348
};
struct ExtensionsParser {
private:
int *lastInteger = nullptr;
public:
bool perMessageDeflate = false;
bool serverNoContextTakeover = false;
bool clientNoContextTakeover = false;
int serverMaxWindowBits = 0;
int clientMaxWindowBits = 0;
int getToken(const char *&in, const char *stop) {
while (in != stop && !isalnum(*in)) {
in++;
}
/* Don't care more than this for now */
static_assert(SHRT_MIN > INT_MIN, "Integer overflow fix is invalid for this platform, report this as a bug!");
int hashedToken = 0;
while (in != stop && (isalnum(*in) || *in == '-' || *in == '_')) {
if (isdigit(*in)) {
/* This check is a quick and incorrect fix for integer overflow
* in oss-fuzz but we don't care as it doesn't matter either way */
if (hashedToken > SHRT_MIN && hashedToken < SHRT_MAX) {
hashedToken = hashedToken * 10 - (*in - '0');
}
} else {
hashedToken += *in;
}
in++;
}
return hashedToken;
}
ExtensionsParser(const char *data, size_t length) {
const char *stop = data + length;
int token = 1;
for (; token && token != TOK_PERMESSAGE_DEFLATE; token = getToken(data, stop));
perMessageDeflate = (token == TOK_PERMESSAGE_DEFLATE);
while ((token = getToken(data, stop))) {
switch (token) {
case TOK_PERMESSAGE_DEFLATE:
return;
case TOK_SERVER_NO_CONTEXT_TAKEOVER:
serverNoContextTakeover = true;
break;
case TOK_CLIENT_NO_CONTEXT_TAKEOVER:
clientNoContextTakeover = true;
break;
case TOK_SERVER_MAX_WINDOW_BITS:
serverMaxWindowBits = 1;
lastInteger = &serverMaxWindowBits;
break;
case TOK_CLIENT_MAX_WINDOW_BITS:
clientMaxWindowBits = 1;
lastInteger = &clientMaxWindowBits;
break;
default:
if (token < 0 && lastInteger) {
*lastInteger = -token;
}
break;
}
}
}
};
template <bool isServer>
struct ExtensionsNegotiator {
protected:
int options;
public:
ExtensionsNegotiator(int wantedOptions) {
options = wantedOptions;
}
std::string generateOffer() {
std::string extensionsOffer;
if (options & Options::PERMESSAGE_DEFLATE) {
extensionsOffer += "permessage-deflate";
if (options & Options::CLIENT_NO_CONTEXT_TAKEOVER) {
extensionsOffer += "; client_no_context_takeover";
}
/* It is questionable sending this improves anything */
/*if (options & Options::SERVER_NO_CONTEXT_TAKEOVER) {
extensionsOffer += "; server_no_context_takeover";
}*/
}
return extensionsOffer;
}
void readOffer(std::string_view offer) {
if (isServer) {
ExtensionsParser extensionsParser(offer.data(), offer.length());
if ((options & PERMESSAGE_DEFLATE) && extensionsParser.perMessageDeflate) {
if (extensionsParser.clientNoContextTakeover || (options & CLIENT_NO_CONTEXT_TAKEOVER)) {
options |= CLIENT_NO_CONTEXT_TAKEOVER;
}
/* We leave this option for us to read even if the client did not send it */
if (extensionsParser.serverNoContextTakeover) {
options |= SERVER_NO_CONTEXT_TAKEOVER;
}/* else {
options &= ~SERVER_NO_CONTEXT_TAKEOVER;
}*/
} else {
options &= ~PERMESSAGE_DEFLATE;
}
} else {
// todo!
}
}
int getNegotiatedOptions() {
return options;
}
};
}
#endif // UWS_WEBSOCKETEXTENSIONS_H

View File

@ -0,0 +1,134 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETHANDSHAKE_H
#define UWS_WEBSOCKETHANDSHAKE_H
#include <cstdint>
#include <cstddef>
namespace uWS {
struct WebSocketHandshake {
template <int N, typename T>
struct static_for {
void operator()(uint32_t *a, uint32_t *b) {
static_for<N - 1, T>()(a, b);
T::template f<N - 1>(a, b);
}
};
template <typename T>
struct static_for<0, T> {
void operator()(uint32_t *a, uint32_t *hash) {}
};
template <int state>
struct Sha1Loop {
static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
static inline uint32_t blk(uint32_t b[16], size_t i) {
return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
}
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
switch (state) {
case 1:
a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 2:
b[i] = blk(b, i);
a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
break;
case 3:
b[(i + 4) % 16] = blk(b, (i + 4) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 4:
b[(i + 8) % 16] = blk(b, (i + 8) % 16);
a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 5:
b[(i + 12) % 16] = blk(b, (i + 12) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 6:
b[i] += a[4 - i];
}
}
};
static inline void sha1(uint32_t hash[5], uint32_t b[16]) {
uint32_t a[5] = {hash[4], hash[3], hash[2], hash[1], hash[0]};
static_for<16, Sha1Loop<1>>()(a, b);
static_for<4, Sha1Loop<2>>()(a, b);
static_for<20, Sha1Loop<3>>()(a, b);
static_for<20, Sha1Loop<4>>()(a, b);
static_for<20, Sha1Loop<5>>()(a, b);
static_for<5, Sha1Loop<6>>()(a, hash);
}
static inline void base64(unsigned char *src, char *dst) {
const char *b64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
for (int i = 0; i < 18; i += 3) {
*dst++ = b64[(src[i] >> 2) & 63];
*dst++ = b64[((src[i] & 3) << 4) | ((src[i + 1] & 240) >> 4)];
*dst++ = b64[((src[i + 1] & 15) << 2) | ((src[i + 2] & 192) >> 6)];
*dst++ = b64[src[i + 2] & 63];
}
*dst++ = b64[(src[18] >> 2) & 63];
*dst++ = b64[((src[18] & 3) << 4) | ((src[19] & 240) >> 4)];
*dst++ = b64[((src[19] & 15) << 2)];
*dst++ = '=';
}
public:
static inline void generate(const char input[24], char output[28]) {
uint32_t b_output[5] = {
0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0
};
uint32_t b_input[16] = {
0, 0, 0, 0, 0, 0, 0x32353845, 0x41464135, 0x2d453931, 0x342d3437, 0x44412d39,
0x3543412d, 0x43354142, 0x30444338, 0x35423131, 0x80000000
};
for (int i = 0; i < 6; i++) {
b_input[i] = (input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24;
}
sha1(b_output, b_input);
uint32_t last_b[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 480};
sha1(b_output, last_b);
for (int i = 0; i < 5; i++) {
uint32_t tmp = b_output[i];
char *bytes = (char *) &b_output[i];
bytes[3] = (char) (tmp & 0xff);
bytes[2] = (char) ((tmp >> 8) & 0xff);
bytes[1] = (char) ((tmp >> 16) & 0xff);
bytes[0] = (char) ((tmp >> 24) & 0xff);
}
base64((unsigned char *) b_output, output);
}
};
}
#endif // UWS_WEBSOCKETHANDSHAKE_H

View File

@ -0,0 +1,443 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_WEBSOCKETPROTOCOL_H
#define UWS_WEBSOCKETPROTOCOL_H
#include <cstdint>
#include <cstring>
#include <cstdlib>
namespace uWS {
enum OpCode : unsigned char {
TEXT = 1,
BINARY = 2,
CLOSE = 8,
PING = 9,
PONG = 10
};
enum {
CLIENT,
SERVER
};
// 24 bytes perfectly
template <bool isServer>
struct WebSocketState {
public:
static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
// 16 bytes
struct State {
unsigned int wantsHead : 1;
unsigned int spillLength : 4;
int opStack : 2; // -1, 0, 1
unsigned int lastFin : 1;
// 15 bytes
unsigned char spill[LONG_MESSAGE_HEADER - 1];
OpCode opCode[2];
State() {
wantsHead = true;
spillLength = 0;
opStack = -1;
lastFin = true;
}
} state;
// 8 bytes
unsigned int remainingBytes = 0;
char mask[isServer ? 4 : 1];
};
namespace protocol {
template <typename T>
T bit_cast(char *c) {
T val;
memcpy(&val, c, sizeof(T));
return val;
}
/* Byte swap for little-endian systems */
template <typename T>
T cond_byte_swap(T value) {
uint32_t endian_test = 1;
if (*((char *)&endian_test)) {
union {
T i;
uint8_t b[sizeof(T)];
} src = { value }, dst;
for (unsigned int i = 0; i < sizeof(value); i++) {
dst.b[i] = src.b[sizeof(value) - 1 - i];
}
return dst.i;
}
return value;
}
// Based on utf8_check.c by Markus Kuhn, 2005
// https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c
// Optimized for predominantly 7-bit content by Alex Hultman, 2016
// Licensed as Zlib, like the rest of this project
static bool isValidUtf8(unsigned char *s, size_t length)
{
for (unsigned char *e = s + length; s != e; ) {
if (s + 4 <= e) {
uint32_t tmp;
memcpy(&tmp, s, 4);
if ((tmp & 0x80808080) == 0) {
s += 4;
continue;
}
}
while (!(*s & 0x80)) {
if (++s == e) {
return true;
}
}
if ((s[0] & 0x60) == 0x40) {
if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
return false;
}
s += 2;
} else if ((s[0] & 0xf0) == 0xe0) {
if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
(s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
return false;
}
s += 3;
} else if ((s[0] & 0xf8) == 0xf0) {
if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 ||
(s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
return false;
}
s += 4;
} else {
return false;
}
}
return true;
}
struct CloseFrame {
uint16_t code;
char *message;
size_t length;
};
static inline CloseFrame parseClosePayload(char *src, size_t length) {
CloseFrame cf = {};
if (length >= 2) {
memcpy(&cf.code, src, 2);
cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
(cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
return {};
}
}
return cf;
}
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
if (code) {
code = cond_byte_swap<uint16_t>(code);
memcpy(dst, &code, 2);
/* It is invalid to pass nullptr to memcpy, even though length is 0 */
if (message) {
memcpy(dst + 2, message, length);
}
return length + 2;
}
return 0;
}
static inline size_t messageFrameSize(size_t messageSize) {
if (messageSize < 126) {
return 2 + messageSize;
} else if (messageSize <= UINT16_MAX) {
return 4 + messageSize;
}
return 10 + messageSize;
}
enum {
SND_CONTINUATION = 1,
SND_NO_FIN = 2,
SND_COMPRESSED = 64
};
template <bool isServer>
static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed) {
size_t messageLength;
size_t headerLength;
if (reportedLength < 126) {
headerLength = 2;
dst[1] = (char) reportedLength;
} else if (reportedLength <= UINT16_MAX) {
headerLength = 4;
dst[1] = 126;
uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t) reportedLength);
memcpy(&dst[2], &tmp, sizeof(uint16_t));
} else {
headerLength = 10;
dst[1] = 127;
uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t) reportedLength);
memcpy(&dst[2], &tmp, sizeof(uint64_t));
}
int flags = 0;
dst[0] = (char) ((flags & SND_NO_FIN ? 0 : 128) | (compressed ? SND_COMPRESSED : 0));
if (!(flags & SND_CONTINUATION)) {
dst[0] |= (char) opCode;
}
char mask[4];
if (!isServer) {
dst[1] |= 0x80;
uint32_t random = rand();
memcpy(mask, &random, 4);
memcpy(dst + headerLength, &random, 4);
headerLength += 4;
}
messageLength = headerLength + length;
memcpy(dst + headerLength, src, length);
if (!isServer) {
// overwrites up to 3 bytes outside of the given buffer!
//WebSocketProtocol<isServer>::unmaskInplace(dst + headerLength, dst + headerLength + length, mask);
// this is not optimal
char *start = dst + headerLength;
char *stop = start + length;
int i = 0;
while (start != stop) {
(*start++) ^= mask[i++ % 4];
}
}
return messageLength;
}
}
// essentially this is only a parser
template <const bool isServer, typename Impl>
struct WIN32_EXPORT WebSocketProtocol {
public:
static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
protected:
static inline bool isFin(char *frame) {return *((unsigned char *) frame) & 128;}
static inline unsigned char getOpCode(char *frame) {return *((unsigned char *) frame) & 15;}
static inline unsigned char payloadLength(char *frame) {return ((unsigned char *) frame)[1] & 127;}
static inline bool rsv23(char *frame) {return *((unsigned char *) frame) & 48;}
static inline bool rsv1(char *frame) {return *((unsigned char *) frame) & 64;}
static inline void unmaskImprecise(char *dst, char *src, char *mask, unsigned int length) {
for (unsigned int n = (length >> 2) + 1; n; n--) {
*(dst++) = *(src++) ^ mask[0];
*(dst++) = *(src++) ^ mask[1];
*(dst++) = *(src++) ^ mask[2];
*(dst++) = *(src++) ^ mask[3];
}
}
static inline void unmaskImpreciseCopyMask(char *dst, char *src, char *maskPtr, unsigned int length) {
char mask[4] = {maskPtr[0], maskPtr[1], maskPtr[2], maskPtr[3]};
unmaskImprecise(dst, src, mask, length);
}
static inline void rotateMask(unsigned int offset, char *mask) {
char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
mask[(0 + offset) % 4] = originalMask[0];
mask[(1 + offset) % 4] = originalMask[1];
mask[(2 + offset) % 4] = originalMask[2];
mask[(3 + offset) % 4] = originalMask[3];
}
static inline void unmaskInplace(char *data, char *stop, char *mask) {
while (data < stop) {
*(data++) ^= mask[0];
*(data++) ^= mask[1];
*(data++) ^= mask[2];
*(data++) ^= mask[3];
}
}
template <unsigned int MESSAGE_HEADER, typename T>
static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
if (getOpCode(src)) {
if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
Impl::forceClose(wState, user);
return true;
}
wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
} else if (wState->state.opStack == -1) {
Impl::forceClose(wState, user);
return true;
}
wState->state.lastFin = isFin(src);
if (Impl::refusePayloadLength(payLength, wState, user)) {
Impl::forceClose(wState, user);
return true;
}
if (payLength + MESSAGE_HEADER <= length) {
if (isServer) {
unmaskImpreciseCopyMask(src + MESSAGE_HEADER - 4, src + MESSAGE_HEADER, src + MESSAGE_HEADER - 4, (unsigned int) payLength);
if (Impl::handleFragment(src + MESSAGE_HEADER - 4, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
return true;
}
} else {
if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
return true;
}
}
if (isFin(src)) {
wState->state.opStack--;
}
src += payLength + MESSAGE_HEADER;
length -= (unsigned int) (payLength + MESSAGE_HEADER);
wState->state.spillLength = 0;
return false;
} else {
wState->state.spillLength = 0;
wState->state.wantsHead = false;
wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
bool fin = isFin(src);
if (isServer) {
memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
unmaskImprecise(src, src + MESSAGE_HEADER, wState->mask, length - MESSAGE_HEADER);
rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
} else {
src += MESSAGE_HEADER;
}
Impl::handleFragment(src, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
return true;
}
}
static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
if (wState->remainingBytes <= length) {
if (isServer) {
int n = wState->remainingBytes >> 2;
unmaskInplace(src, src + n * 4, wState->mask);
for (int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
src[n * 4 + i] ^= wState->mask[i];
}
}
if (Impl::handleFragment(src, wState->remainingBytes, 0, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
return false;
}
if (wState->state.lastFin) {
wState->state.opStack--;
}
src += wState->remainingBytes;
length -= wState->remainingBytes;
wState->state.wantsHead = true;
return true;
} else {
if (isServer) {
unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
}
wState->remainingBytes -= length;
if (Impl::handleFragment(src, length, wState->remainingBytes, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
return false;
}
if (isServer && length % 4) {
rotateMask(4 - (length % 4), wState->mask);
}
return false;
}
}
public:
WebSocketProtocol() {
}
static inline void consume(char *src, unsigned int length, WebSocketState<isServer> *wState, void *user) {
if (wState->state.spillLength) {
src -= wState->state.spillLength;
length += wState->state.spillLength;
memcpy(src, wState->state.spill, wState->state.spillLength);
}
if (wState->state.wantsHead) {
parseNext:
while (length >= SHORT_MESSAGE_HEADER) {
// invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame
if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) || (getOpCode(src) > 2 && getOpCode(src) < 8) ||
getOpCode(src) > 10 || (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
Impl::forceClose(wState, user);
return;
}
if (payloadLength(src) < 126) {
if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(payloadLength(src), src, length, wState, user)) {
return;
}
} else if (payloadLength(src) == 126) {
if (length < MEDIUM_MESSAGE_HEADER) {
break;
} else if(consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(protocol::cond_byte_swap<uint16_t>(protocol::bit_cast<uint16_t>(src + 2)), src, length, wState, user)) {
return;
}
} else if (length < LONG_MESSAGE_HEADER) {
break;
} else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(protocol::cond_byte_swap<uint64_t>(protocol::bit_cast<uint64_t>(src + 2)), src, length, wState, user)) {
return;
}
}
if (length) {
memcpy(wState->state.spill, src, length);
wState->state.spillLength = length & 0xf;
}
} else if (consumeContinuation(src, length, wState, user)) {
goto parseNext;
}
}
static const int CONSUME_POST_PADDING = 4;
static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
};
}
#endif // UWS_WEBSOCKETPROTOCOL_H

View File

@ -0,0 +1,23 @@
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

File diff suppressed because it is too large Load Diff