diff --git a/CMakeLists.txt b/CMakeLists.txt index bbf014e56..08295d56b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,26 +1,25 @@ project(PID) - cmake_minimum_required (VERSION 3.5) - add_definitions(-std=c++11) - -set(CXX_FLAGS "-Wall") +set(CXX_FLAGS "-W1") set(CMAKE_CXX_FLAGS, "${CXX_FLAGS}") +if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") + + set(sources src/PID.cpp src/main.cpp src/uWS/Extensions.cpp src/uWS/Group.cpp src/uWS/WebSocketImpl.cpp src/uWS/Networking.cpp src/uWS/Hub.cpp src/uWS/Node.cpp src/uWS/WebSocket.cpp src/uWS/HTTPSocket.cpp src/uWS/Socket.cpp src/uWS/uUV.cpp) + set_source_files_properties(${sources} PROPERTIES COMPILE_FLAGS "-D_USE_MATH_DEFINES") -set(sources src/PID.cpp src/main.cpp) - +else(${CMAKE_SYSTEM_NAME} MATCHES "Windows") + set(sources src/PID.cpp src/main.cpp) +endif(${CMAKE_SYSTEM_NAME} MATCHES "Windows") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - -include_directories(/usr/local/include) -include_directories(/usr/local/opt/openssl/include) -link_directories(/usr/local/lib) -link_directories(/usr/local/opt/openssl/lib) -link_directories(/usr/local/Cellar/libuv/1.11.0/lib) - + include_directories(/usr/local/include) + include_directories(/usr/local/opt/openssl/include) + link_directories(/usr/local/lib) + link_directories(/usr/local/opt/openssl/lib) + link_directories(/usr/local/Cellar/libuv/1.11.0/lib) endif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - - add_executable(pid ${sources}) - -target_link_libraries(pid z ssl uv uWS) +if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows") + target_link_libraries(pid z ssl uv uWS) +endif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows") \ No newline at end of file diff --git a/CMakeSettings.json b/CMakeSettings.json new file mode 100644 index 000000000..fa524fe3d --- /dev/null +++ b/CMakeSettings.json @@ -0,0 +1,44 @@ +{ + // See https://go.microsoft.com//fwlink//?linkid=834763 for more information about this file. + "configurations": [ + { + "name": "x86-Debug", + "generator": "Visual Studio 15 2017", + "configurationType": "Debug", + "buildRoot": "${env.LOCALAPPDATA}\\CMakeBuild\\${workspaceHash}\\build\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "-m -v:minimal", + + "variables": [ + { + "name": "DCMAKE_TOOLCHAIN_FILE", + "value": "C:/vcpkg/scripts/buildsystems/vcpkg.cmake" + } + ] + }, + { + "name": "x86-Release", + "generator": "Visual Studio 15 2017", + "configurationType" : "Release", + "buildRoot": "${env.LOCALAPPDATA}\\CMakeBuild\\${workspaceHash}\\build\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "-m -v:minimal" + }, + { + "name": "x64-Debug", + "generator": "Visual Studio 15 2017 Win64", + "configurationType" : "Debug", + "buildRoot": "${env.LOCALAPPDATA}\\CMakeBuild\\${workspaceHash}\\build\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "-m -v:minimal" + }, + { + "name": "x64-Release", + "generator": "Visual Studio 15 2017 Win64", + "configurationType" : "Release", + "buildRoot": "${env.LOCALAPPDATA}\\CMakeBuild\\${workspaceHash}\\build\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "-m -v:minimal" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 9728ad7fa..ab068e926 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,17 @@ Self-Driving Car Engineer Nanodegree Program 1. Clone this repo. 2. Make a build directory: `mkdir build && cd build` 3. Compile: `cmake .. && make` -4. Run it: `./pid`. +4. Run it: `./pid`. + +## Windows Install Instructions + +1. Install, in your root `c:/` directory, vcpkg https://github.com/Microsoft/vcpkg (15 - 30 minutes) +2. Be sure while installing vcpkg to carefully follow all instructions! This is NOT an easy install process. +3. Install python 2.7 (dependency for libuv) +4. cd to directory with vcpkg .exe and `./vcpkg install uWebsockets` (20 min, mostly automatic) +5. Open CMakeSetting.json, check if `C:/vcpkg/scripts/buildsystems/vcpkg.cmake` is the correct directory to your vcpkg and `DCMAKE_TOOLCHAIN_FILE` matches the output from vcpkg integrate. +6. Open in VS17 community edition, build pid.exe in x86 debug. + ## Editor Settings diff --git a/src/PID.cpp b/src/PID.cpp index 4290e044e..9fc8f8083 100644 --- a/src/PID.cpp +++ b/src/PID.cpp @@ -17,5 +17,6 @@ void PID::UpdateError(double cte) { } double PID::TotalError() { + return 1; } diff --git a/src/main.cpp b/src/main.cpp index 936ccd69f..932cb8046 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,9 +1,14 @@ -#include +#define _USE_MATH_DEFINES + +#include "uWS/uWS.h" #include #include "json.hpp" #include "PID.h" #include + +using namespace std; + // for convenience using json = nlohmann::json; @@ -12,28 +17,35 @@ constexpr double pi() { return M_PI; } double deg2rad(double x) { return x * pi() / 180; } double rad2deg(double x) { return x * 180 / pi(); } + + // Checks if the SocketIO event has JSON data. // If there is data the JSON object in string format will be returned, // else the empty string "" will be returned. -std::string hasData(std::string s) { +std::stringstream hasData(std::string s) { auto found_null = s.find("null"); auto b1 = s.find_first_of("["); auto b2 = s.find_last_of("]"); if (found_null != std::string::npos) { - return ""; + return std::stringstream(); } else if (b1 != std::string::npos && b2 != std::string::npos) { - return s.substr(b1, b2 - b1 + 1); + std::stringstream tmp = std::stringstream(); + tmp.str(s.substr(b1, b2 - b1 + 1)); + return tmp; } - return ""; + return std::stringstream(); } + int main() { + uWS::Hub h; PID pid; // TODO: Initialize the pid variable. + h.onMessage([&pid](uWS::WebSocket ws, char *data, size_t length, uWS::OpCode opCode) { // "42" at the start of the message means there's a websocket message event. @@ -41,8 +53,8 @@ int main() // The 2 signifies a websocket event if (length && length > 2 && data[0] == '4' && data[1] == '2') { - auto s = hasData(std::string(data).substr(0, length)); - if (s != "") { + auto s = hasData(std::string(data)); + if (s.str() != "") { auto j = json::parse(s); std::string event = j[0].get(); if (event == "telemetry") { @@ -50,14 +62,14 @@ int main() double cte = std::stod(j[1]["cte"].get()); double speed = std::stod(j[1]["speed"].get()); double angle = std::stod(j[1]["steering_angle"].get()); - double steer_value; /* * TODO: Calcuate steering value here, remember the steering value is * [-1, 1]. * NOTE: Feel free to play around with the throttle and speed. Maybe use * another PID controller to control the speed! */ - + + // DEBUG std::cout << "CTE: " << cte << " Steering Value: " << steer_value << std::endl; @@ -66,42 +78,28 @@ int main() msgJson["throttle"] = 0.3; auto msg = "42[\"steer\"," + msgJson.dump() + "]"; std::cout << msg << std::endl; - ws.send(msg.data(), msg.length(), uWS::OpCode::TEXT); + (ws).send(msg.data(), msg.length(), uWS::OpCode::TEXT); } - } else { + } + else { // Manual driving std::string msg = "42[\"manual\",{}]"; - ws.send(msg.data(), msg.length(), uWS::OpCode::TEXT); + (ws).send(msg.data(), msg.length(), uWS::OpCode::TEXT); } } }); - // We don't need this since we're not using HTTP but if it's removed the program - // doesn't compile :-( - h.onHttpRequest([](uWS::HttpResponse *res, uWS::HttpRequest req, char *data, size_t, size_t) { - const std::string s = "

Hello world!

"; - if (req.getUrl().valueLength == 1) - { - res->end(s.data(), s.length()); - } - else - { - // i guess this should be done more gracefully? - res->end(nullptr, 0); - } - }); - h.onConnection([&h](uWS::WebSocket ws, uWS::HttpRequest req) { std::cout << "Connected!!!" << std::endl; }); h.onDisconnection([&h](uWS::WebSocket ws, int code, char *message, size_t length) { - ws.close(); + (ws).close(); std::cout << "Disconnected" << std::endl; }); int port = 4567; - if (h.listen(port)) + if (h.listen("0.0.0.0", port)) { std::cout << "Listening to port " << port << std::endl; } @@ -111,4 +109,4 @@ int main() return -1; } h.run(); -} +} \ No newline at end of file diff --git a/src/uWS/Extensions.cpp b/src/uWS/Extensions.cpp new file mode 100644 index 000000000..ef8f9da7e --- /dev/null +++ b/src/uWS/Extensions.cpp @@ -0,0 +1,131 @@ +#include "Extensions.h" + +namespace uWS { + +enum ExtensionTokens { + TOK_PERMESSAGE_DEFLATE = 1838, + TOK_SERVER_NO_CONTEXT_TAKEOVER = 2807, + TOK_CLIENT_NO_CONTEXT_TAKEOVER = 2783, + TOK_SERVER_MAX_WINDOW_BITS = 2372, + TOK_CLIENT_MAX_WINDOW_BITS = 2348 +}; + +class ExtensionsParser { +private: + int *lastInteger = nullptr; + +public: + bool perMessageDeflate = false; + bool serverNoContextTakeover = false; + bool clientNoContextTakeover = false; + int serverMaxWindowBits = 0; + int clientMaxWindowBits = 0; + + int getToken(const char *&in, const char *stop); + ExtensionsParser(const char *data, size_t length); +}; + +int ExtensionsParser::getToken(const char *&in, const char *stop) { + while (!isalnum(*in) && in != stop) { + in++; + } + + int hashedToken = 0; + while (isalnum(*in) || *in == '-' || *in == '_') { + if (isdigit(*in)) { + hashedToken = hashedToken * 10 - (*in - '0'); + } else { + hashedToken += *in; + } + in++; + } + return hashedToken; +} + +ExtensionsParser::ExtensionsParser(const char *data, size_t length) { + const char *stop = data + length; + int token = 1; + for (; token && token != TOK_PERMESSAGE_DEFLATE; token = getToken(data, stop)); + + perMessageDeflate = (token == TOK_PERMESSAGE_DEFLATE); + while ((token = getToken(data, stop))) { + switch (token) { + case TOK_PERMESSAGE_DEFLATE: + return; + case TOK_SERVER_NO_CONTEXT_TAKEOVER: + serverNoContextTakeover = true; + break; + case TOK_CLIENT_NO_CONTEXT_TAKEOVER: + clientNoContextTakeover = true; + break; + case TOK_SERVER_MAX_WINDOW_BITS: + serverMaxWindowBits = 1; + lastInteger = &serverMaxWindowBits; + break; + case TOK_CLIENT_MAX_WINDOW_BITS: + clientMaxWindowBits = 1; + lastInteger = &clientMaxWindowBits; + break; + default: + if (token < 0 && lastInteger) { + *lastInteger = -token; + } + break; + } + } +} + +template +ExtensionsNegotiator::ExtensionsNegotiator(int wantedOptions) { + options = wantedOptions; +} + +template +std::string ExtensionsNegotiator::generateOffer() { + std::string extensionsOffer; + if (options & Options::PERMESSAGE_DEFLATE) { + extensionsOffer += "permessage-deflate"; + + if (options & Options::CLIENT_NO_CONTEXT_TAKEOVER) { + extensionsOffer += "; client_no_context_takeover"; + } + + if (options & Options::SERVER_NO_CONTEXT_TAKEOVER) { + extensionsOffer += "; server_no_context_takeover"; + } + } + + return extensionsOffer; +} + +template +void ExtensionsNegotiator::readOffer(std::string offer) { + if (isServer) { + ExtensionsParser extensionsParser(offer.data(), offer.length()); + if ((options & PERMESSAGE_DEFLATE) && extensionsParser.perMessageDeflate) { + if (extensionsParser.clientNoContextTakeover || (options & CLIENT_NO_CONTEXT_TAKEOVER)) { + options |= CLIENT_NO_CONTEXT_TAKEOVER; + } + + if (extensionsParser.serverNoContextTakeover) { + options |= SERVER_NO_CONTEXT_TAKEOVER; + } else { + options &= ~SERVER_NO_CONTEXT_TAKEOVER; + } + } else { + options &= ~PERMESSAGE_DEFLATE; + } + } else { + // todo! + } +} + +template +int ExtensionsNegotiator::getNegotiatedOptions() { + return options; +} + +template class ExtensionsNegotiator; +template class ExtensionsNegotiator; + +} diff --git a/src/uWS/Extensions.h b/src/uWS/Extensions.h new file mode 100644 index 000000000..763b4d2cb --- /dev/null +++ b/src/uWS/Extensions.h @@ -0,0 +1,29 @@ +#ifndef EXTENSIONS_UWS_H +#define EXTENSIONS_UWS_H + +#include + +namespace uWS { + +enum Options : unsigned int { + NO_OPTIONS = 0, + PERMESSAGE_DEFLATE = 1, + SERVER_NO_CONTEXT_TAKEOVER = 2, + CLIENT_NO_CONTEXT_TAKEOVER = 4, + NO_DELAY = 8 +}; + +template +class ExtensionsNegotiator { +private: + int options; +public: + ExtensionsNegotiator(int wantedOptions); + std::string generateOffer(); + void readOffer(std::string offer); + int getNegotiatedOptions(); +}; + +} + +#endif // EXTENSIONS_UWS_H diff --git a/src/uWS/Group.cpp b/src/uWS/Group.cpp new file mode 100644 index 000000000..a03940550 --- /dev/null +++ b/src/uWS/Group.cpp @@ -0,0 +1,282 @@ +#include "Group.h" +#include "Hub.h" + +namespace uWS { + +template +void Group::setUserData(void *user) { + this->userData = user; +} + +template +void *Group::getUserData() { + return userData; +} + +template +void Group::timerCallback(uv_timer_t *timer) { + Group *group = (Group *) timer->data; + + group->forEach([](uWS::WebSocket ws) { + typename uWS::WebSocket::Data *webSocketData = (typename uWS::WebSocket::Data *) ws.getSocketData(); + if (webSocketData->hasOutstandingPong) { + ws.terminate(); + } else { + webSocketData->hasOutstandingPong = true; + } + }); + + if (group->userPingMessage.length()) { + group->broadcast(group->userPingMessage.data(), group->userPingMessage.length(), OpCode::TEXT); + } else { + group->broadcast(nullptr, 0, OpCode::PING); + } +} + +template +void Group::startAutoPing(int intervalMs, std::string userMessage) { + timer = new uv_timer_t; + uv_timer_init(loop, timer); + timer->data = this; + uv_timer_start(timer, timerCallback, intervalMs, intervalMs); + userPingMessage = userMessage; +} + +// WIP +template +void Group::addHttpSocket(uv_poll_t *httpSocket) { + + // always clear last chain! + ((uS::SocketData *) httpSocket->data)->next = nullptr; + ((uS::SocketData *) httpSocket->data)->prev = nullptr; + + if (httpSocketHead) { + uS::SocketData *nextData = (uS::SocketData *) httpSocketHead->data; + nextData->prev = httpSocket; + uS::SocketData *data = (uS::SocketData *) httpSocket->data; + data->next = httpSocketHead; + } else { + httpTimer = new uv_timer_t; + uv_timer_init(hub->getLoop(), httpTimer); + httpTimer->data = this; + uv_timer_start(httpTimer, [](uv_timer_t *httpTimer) { + Group *group = (Group *) httpTimer->data; + group->forEachHttpSocket([](HttpSocket httpSocket) { + if (httpSocket.getData()->missedDeadline) { + // recursive? don't think so! + httpSocket.terminate(); + } else if (!httpSocket.getData()->outstandingResponsesHead) { + httpSocket.getData()->missedDeadline = true; + } + }); + }, 1000, 1000); + } + httpSocketHead = httpSocket; +} + +// WIP +template +void Group::removeHttpSocket(uv_poll_t *httpSocket) { + uS::SocketData *socketData = (uS::SocketData *) httpSocket->data; + if (iterators.size()) { + iterators.top() = socketData->next; + } + if (socketData->prev == socketData->next) { + httpSocketHead = (uv_poll_t *) nullptr; + + uv_timer_stop(httpTimer); + uv_close(httpTimer, [](uv_handle_t *handle) { + delete (uv_timer_t *) handle; + }); + + } else { + if (socketData->prev) { + ((uS::SocketData *) socketData->prev->data)->next = socketData->next; + } else { + httpSocketHead = socketData->next; + } + if (socketData->next) { + ((uS::SocketData *) socketData->next->data)->prev = socketData->prev; + } + } +} + +template +void Group::addWebSocket(uv_poll_t *webSocket) { + + // always clear last chain! + ((uS::SocketData *) webSocket->data)->next = nullptr; + ((uS::SocketData *) webSocket->data)->prev = nullptr; + + if (webSocketHead) { + uS::SocketData *nextData = (uS::SocketData *) webSocketHead->data; + nextData->prev = webSocket; + uS::SocketData *data = (uS::SocketData *) webSocket->data; + data->next = webSocketHead; + } + webSocketHead = webSocket; +} + +template +void Group::removeWebSocket(uv_poll_t *webSocket) { + uS::SocketData *socketData = (uS::SocketData *) webSocket->data; + if (iterators.size()) { + iterators.top() = socketData->next; + } + if (socketData->prev == socketData->next) { + webSocketHead = (uv_poll_t *) nullptr; + } else { + if (socketData->prev) { + ((uS::SocketData *) socketData->prev->data)->next = socketData->next; + } else { + webSocketHead = socketData->next; + } + if (socketData->next) { + ((uS::SocketData *) socketData->next->data)->prev = socketData->prev; + } + } +} + +template +Group::Group(int extensionOptions, Hub *hub, uS::NodeData *nodeData) : uS::NodeData(*nodeData), hub(hub), extensionOptions(extensionOptions) { + connectionHandler = [](WebSocket, HttpRequest) {}; + messageHandler = [](WebSocket, char *, size_t, OpCode) {}; + disconnectionHandler = [](WebSocket, int, char *, size_t) {}; + pingHandler = pongHandler = [](WebSocket, char *, size_t) {}; + errorHandler = [](errorType) {}; + httpRequestHandler = [](HttpResponse *, HttpRequest, char *, size_t, size_t) {}; + httpConnectionHandler = [](HttpSocket) {}; + httpDisconnectionHandler = [](HttpSocket) {}; + httpCancelledRequestHandler = [](HttpResponse *) {}; + httpDataHandler = [](HttpResponse *, char *, size_t, size_t) {}; + + this->extensionOptions |= CLIENT_NO_CONTEXT_TAKEOVER | SERVER_NO_CONTEXT_TAKEOVER; +} + +template +void Group::stopListening() { + if (isServer) { + uS::ListenData *listenData = (uS::ListenData *) user; + if (listenData) { + if (listenData->listenPoll) + uS::Socket(listenData->listenPoll).close(); + else if (listenData->listenTimer) { + uv_os_sock_t fd = listenData->sock; + uv_timer_stop(listenData->listenTimer); + ::close(fd); + + SSL *ssl = listenData->ssl; + if (ssl) { + SSL_free(ssl); + } + + uv_close(listenData->listenTimer, [](uv_handle_t *handle) { + delete handle; + }); + } + delete listenData; + } + } + + if (async) { + uv_close(async, [](uv_handle_t *h) { + delete (uv_async_t *) h; + }); + } +} + +template +void Group::onConnection(std::function, HttpRequest)> handler) { + connectionHandler = handler; +} + +template +void Group::onMessage(std::function, char *, size_t, OpCode)> handler) { + messageHandler = handler; +} + +template +void Group::onDisconnection(std::function, int, char *, size_t)> handler) { + disconnectionHandler = handler; +} + +template +void Group::onPing(std::function, char *, size_t)> handler) { + pingHandler = handler; +} + +template +void Group::onPong(std::function, char *, size_t)> handler) { + pongHandler = handler; +} + +template +void Group::onError(std::function handler) { + errorHandler = handler; +} + +template +void Group::onHttpConnection(std::function)> handler) { + httpConnectionHandler = handler; +} + +template +void Group::onHttpRequest(std::function handler) { + httpRequestHandler = handler; +} + +template +void Group::onHttpData(std::function handler) { + httpDataHandler = handler; +} + +template +void Group::onHttpDisconnection(std::function)> handler) { + httpDisconnectionHandler = handler; +} + +template +void Group::onCancelledHttpRequest(std::function handler) { + httpCancelledRequestHandler = handler; +} + +template +void Group::onHttpUpgrade(std::function, HttpRequest)> handler) { + httpUpgradeHandler = handler; +} + +template +void Group::broadcast(const char *message, size_t length, OpCode opCode) { + typename WebSocket::PreparedMessage *preparedMessage = WebSocket::prepareMessage((char *) message, length, opCode, false); + forEach([preparedMessage](uWS::WebSocket ws) { + ws.sendPrepared(preparedMessage); + }); + WebSocket::finalizeMessage(preparedMessage); +} + +template +void Group::terminate() { + forEach([](uWS::WebSocket ws) { + ws.terminate(); + }); + stopListening(); +} + +template +void Group::close(int code, char *message, size_t length) { + forEach([code, message, length](uWS::WebSocket ws) { + ws.close(code, message, length); + }); + stopListening(); + if (timer) { + uv_timer_stop(timer); + uv_close(timer, [](uv_handle_t *handle) { + delete (uv_timer_t *) handle; + }); + } +} + +template struct Group; +template struct Group; + +} diff --git a/src/uWS/Group.h b/src/uWS/Group.h new file mode 100644 index 000000000..1d4d231a3 --- /dev/null +++ b/src/uWS/Group.h @@ -0,0 +1,124 @@ +#ifndef GROUP_UWS_H +#define GROUP_UWS_H + +#include "WebSocket.h" +#include "HTTPSocket.h" +#include "Extensions.h" +#include +#include + +namespace uWS { + +struct Hub; + +template +struct WIN32_EXPORT Group : uS::NodeData { + friend struct Hub; + std::function, HttpRequest)> connectionHandler; + std::function, char *message, size_t length, OpCode opCode)> messageHandler; + std::function, int code, char *message, size_t length)> disconnectionHandler; + std::function, char *, size_t)> pingHandler; + std::function, char *, size_t)> pongHandler; + + std::function)> httpConnectionHandler; + std::function httpRequestHandler; + std::function httpDataHandler; + std::function httpCancelledRequestHandler; + + std::function)> httpDisconnectionHandler; + std::function, HttpRequest)> httpUpgradeHandler; + + using errorType = typename std::conditional::type; + std::function errorHandler; + + Hub *hub; + int extensionOptions; + uv_timer_t *timer = nullptr; + std::string userPingMessage; + + // todo: cannot be named user, collides with parent! + void *userData = nullptr; + void setUserData(void *user); + void *getUserData(); + void startAutoPing(int intervalMs, std::string userMessage = ""); + static void timerCallback(uv_timer_t *timer); + + uv_poll_t *webSocketHead = nullptr, *httpSocketHead = nullptr; + void addWebSocket(uv_poll_t *webSocket); + void removeWebSocket(uv_poll_t *webSocket); + + uv_timer_t *httpTimer = nullptr; + void addHttpSocket(uv_poll_t *httpSocket); + void removeHttpSocket(uv_poll_t *httpSocket); + + + std::stack iterators; + +protected: + Group(int extensionOptions, Hub *hub, uS::NodeData *nodeData); + void stopListening(); + +public: + void onConnection(std::function, HttpRequest)> handler); + void onMessage(std::function, char *, size_t, OpCode)> handler); + void onDisconnection(std::function, int code, char *message, size_t length)> handler); + void onPing(std::function, char *, size_t)> handler); + void onPong(std::function, char *, size_t)> handler); + void onError(std::function handler); + + void onHttpConnection(std::function)> handler); + void onHttpRequest(std::function handler); + void onHttpData(std::function handler); + void onHttpDisconnection(std::function)> handler); + void onCancelledHttpRequest(std::function handler); + void onHttpUpgrade(std::function, HttpRequest)> handler); + + + void broadcast(const char *message, size_t length, OpCode opCode); + void terminate(); + void close(int code = 1000, char *message = nullptr, size_t length = 0); + using NodeData::addAsync; + + // todo: handle nested forEachs with removeWebSocket + template + void forEach(const F &cb) { + uv_poll_t *iterator = webSocketHead; + iterators.push(iterator); + while (iterator) { + uv_poll_t *lastIterator = iterator; + cb(WebSocket(iterator)); + iterator = iterators.top(); + if (lastIterator == iterator) { + iterator = ((uS::SocketData *) iterator->data)->next; + iterators.top() = iterator; + } + } + iterators.pop(); + } + + // duplicated code for now! + template + void forEachHttpSocket(const F &cb) { + uv_poll_t *iterator = httpSocketHead; + iterators.push(iterator); + while (iterator) { + uv_poll_t *lastIterator = iterator; + cb(HttpSocket(iterator)); + iterator = iterators.top(); + if (lastIterator == iterator) { + iterator = ((uS::SocketData *) iterator->data)->next; + iterators.top() = iterator; + } + } + iterators.pop(); + } +}; + +template +Group *getGroup(uS::Socket s) { + return static_cast *>(s.getSocketData()->nodeData); +} + +} + +#endif // GROUP_UWS_H diff --git a/src/uWS/HTTPSocket.cpp b/src/uWS/HTTPSocket.cpp new file mode 100644 index 000000000..5f98796bf --- /dev/null +++ b/src/uWS/HTTPSocket.cpp @@ -0,0 +1,300 @@ +#include "HTTPSocket.h" +#include "Group.h" +#include "Extensions.h" +#include + +#define MAX_HEADERS 100 +#define MAX_HEADER_BUFFER_SIZE 4096 +#define FORCE_SLOW_PATH false + +#include + +namespace uWS { + +// UNSAFETY NOTE: assumes *end == '\r' (might unref end pointer) +char *getHeaders(char *buffer, char *end, Header *headers, size_t maxHeaders) { + for (unsigned int i = 0; i < maxHeaders; i++) { + for (headers->key = buffer; (*buffer != ':') & (*buffer > 32); *(buffer++) |= 32); + if (*buffer == '\r') { + if ((buffer != end) & (buffer[1] == '\n') & (i > 0)) { + headers->key = nullptr; + return buffer + 2; + } else { + return nullptr; + } + } else { + headers->keyLength = buffer - headers->key; + for (buffer++; (*buffer == ':' || *buffer < 33) && *buffer != '\r'; buffer++); + headers->value = buffer; + buffer = (char *) memchr(buffer, '\r', end - buffer); //for (; *buffer != '\r'; buffer++); + if (buffer /*!= end*/ && buffer[1] == '\n') { + headers->valueLength = buffer - headers->value; + buffer += 2; + headers++; + } else { + return nullptr; + } + } + } + return nullptr; +} + +// UNSAFETY NOTE: assumes 24 byte input length +static void base64(unsigned char *src, char *dst) { + static const char *b64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + for (int i = 0; i < 18; i += 3) { + *dst++ = b64[(src[i] >> 2) & 63]; + *dst++ = b64[((src[i] & 3) << 4) | ((src[i + 1] & 240) >> 4)]; + *dst++ = b64[((src[i + 1] & 15) << 2) | ((src[i + 2] & 192) >> 6)]; + *dst++ = b64[src[i + 2] & 63]; + } + *dst++ = b64[(src[18] >> 2) & 63]; + *dst++ = b64[((src[18] & 3) << 4) | ((src[19] & 240) >> 4)]; + *dst++ = b64[((src[19] & 15) << 2)]; + *dst++ = '='; +} + +template +void HttpSocket::onData(uS::Socket s, char *data, int length) { + HttpSocket httpSocket(s); + HttpSocket::Data *httpData = httpSocket.getData(); + + httpSocket.cork(true); + + if (httpData->contentLength) { + httpData->missedDeadline = false; + if (httpData->contentLength >= length) { + getGroup(s)->httpDataHandler(httpData->outstandingResponsesTail, data, length, httpData->contentLength -= length); + return; + } else { + getGroup(s)->httpDataHandler(httpData->outstandingResponsesTail, data, httpData->contentLength, 0); + data += httpData->contentLength; + length -= httpData->contentLength; + httpData->contentLength = 0; + } + } + + if (FORCE_SLOW_PATH || httpData->httpBuffer.length()) { + if (httpData->httpBuffer.length() + length > MAX_HEADER_BUFFER_SIZE) { + httpSocket.onEnd(s); + return; + } + + httpData->httpBuffer.reserve(httpData->httpBuffer.length() + length + WebSocketProtocol::CONSUME_POST_PADDING); + httpData->httpBuffer.append(data, length); + data = (char *) httpData->httpBuffer.data(); + length = httpData->httpBuffer.length(); + } + + char *end = data + length; + char *cursor = data; + *end = '\r'; + Header headers[MAX_HEADERS]; + do { + char *lastCursor = cursor; + if ((cursor = getHeaders(cursor, end, headers, MAX_HEADERS))) { + HttpRequest req(headers); + + if (isServer) { + headers->valueLength = std::max(0, headers->valueLength - 9); + httpData->missedDeadline = false; + if (req.getHeader("upgrade", 7)) { + if (getGroup(s)->httpUpgradeHandler) { + getGroup(s)->httpUpgradeHandler(HttpSocket(s), req); + } else { + Header secKey = req.getHeader("sec-websocket-key", 17); + Header extensions = req.getHeader("sec-websocket-extensions", 24); + Header subprotocol = req.getHeader("sec-websocket-protocol", 22); + if (secKey.valueLength == 24) { + bool perMessageDeflate; + httpSocket.upgrade(secKey.value, extensions.value, extensions.valueLength, + subprotocol.value, subprotocol.valueLength, &perMessageDeflate); + getGroup(s)->removeHttpSocket(s); + s.enterState>(new WebSocket::Data(perMessageDeflate, httpData)); + getGroup(s)->addWebSocket(s); + s.cork(true); + getGroup(s)->connectionHandler(WebSocket(s), req); + s.cork(false); + delete httpData; + } else { + httpSocket.onEnd(s); + } + } + return; + } else { + if (getGroup(s)->httpRequestHandler) { + + HttpResponse *res = HttpResponse::allocateResponse(httpSocket, httpData); + if (httpData->outstandingResponsesTail) { + httpData->outstandingResponsesTail->next = res; + } else { + httpData->outstandingResponsesHead = res; + } + httpData->outstandingResponsesTail = res; + + Header contentLength; + if (req.getMethod() != HttpMethod::METHOD_GET && (contentLength = req.getHeader("content-length", 14))) { + httpData->contentLength = atoi(contentLength.value); + size_t bytesToRead = std::min(httpData->contentLength, end - cursor); + getGroup(s)->httpRequestHandler(res, req, cursor, bytesToRead, httpData->contentLength -= bytesToRead); + cursor += bytesToRead; + } else { + getGroup(s)->httpRequestHandler(res, req, nullptr, 0, 0); + } + + if (s.isClosed() || s.isShuttingDown()) { + return; + } + } else { + httpSocket.onEnd(s); + return; + } + } + } else { + if (req.getHeader("upgrade", 7)) { + s.enterState>(new WebSocket::Data(false, httpData)); + + httpSocket.cancelTimeout(); + httpSocket.setUserData(httpData->httpUser); + getGroup(s)->addWebSocket(s); + s.cork(true); + getGroup(s)->connectionHandler(WebSocket(s), req); + s.cork(false); + + if (!(s.isClosed() || s.isShuttingDown())) { + WebSocketProtocol *kws = (WebSocketProtocol *) ((WebSocket::Data *) s.getSocketData()); + kws->consume(cursor, end - cursor, s); + } + + delete httpData; + } else { + httpSocket.onEnd(s); + } + return; + } + } else { + if (!httpData->httpBuffer.length()) { + if (length > MAX_HEADER_BUFFER_SIZE) { + httpSocket.onEnd(s); + } else { + httpData->httpBuffer.append(lastCursor, end - lastCursor); + } + } + return; + } + } while(cursor != end); + + httpSocket.cork(false); + httpData->httpBuffer.clear(); +} + +// todo: make this into a transformer and make use of sendTransformed +template +void HttpSocket::upgrade(const char *secKey, const char *extensions, size_t extensionsLength, + const char *subprotocol, size_t subprotocolLength, bool *perMessageDeflate) { + + uS::SocketData::Queue::Message *messagePtr; + + if (isServer) { + *perMessageDeflate = false; + std::string extensionsResponse; + if (extensionsLength) { + Group *group = getGroup(*this); + ExtensionsNegotiator extensionsNegotiator(group->extensionOptions); + extensionsNegotiator.readOffer(std::string(extensions, extensionsLength)); + extensionsResponse = extensionsNegotiator.generateOffer(); + if (extensionsNegotiator.getNegotiatedOptions() & PERMESSAGE_DEFLATE) { + *perMessageDeflate = true; + } + } + + unsigned char shaInput[] = "XXXXXXXXXXXXXXXXXXXXXXXX258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + memcpy(shaInput, secKey, 24); + unsigned char shaDigest[SHA_DIGEST_LENGTH]; + SHA1(shaInput, sizeof(shaInput) - 1, shaDigest); + + char upgradeBuffer[1024]; + memcpy(upgradeBuffer, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ", 97); + base64(shaDigest, upgradeBuffer + 97); + memcpy(upgradeBuffer + 125, "\r\n", 2); + size_t upgradeResponseLength = 127; + if (extensionsResponse.length()) { + memcpy(upgradeBuffer + upgradeResponseLength, "Sec-WebSocket-Extensions: ", 26); + memcpy(upgradeBuffer + upgradeResponseLength + 26, extensionsResponse.data(), extensionsResponse.length()); + memcpy(upgradeBuffer + upgradeResponseLength + 26 + extensionsResponse.length(), "\r\n", 2); + upgradeResponseLength += 26 + extensionsResponse.length() + 2; + } + if (subprotocolLength) { + memcpy(upgradeBuffer + upgradeResponseLength, "Sec-WebSocket-Protocol: ", 24); + memcpy(upgradeBuffer + upgradeResponseLength + 24, subprotocol, subprotocolLength); + memcpy(upgradeBuffer + upgradeResponseLength + 24 + subprotocolLength, "\r\n", 2); + upgradeResponseLength += 24 + subprotocolLength + 2; + } + static char stamp[] = "Sec-WebSocket-Version: 13\r\nWebSocket-Server: uWebSockets\r\n\r\n"; + memcpy(upgradeBuffer + upgradeResponseLength, stamp, sizeof(stamp) - 1); + upgradeResponseLength += sizeof(stamp) - 1; + + messagePtr = allocMessage(upgradeResponseLength, upgradeBuffer); + } else { + messagePtr = allocMessage(getData()->httpBuffer.length(), getData()->httpBuffer.data()); + getData()->httpBuffer.clear(); + } + + bool wasTransferred; + if (write(messagePtr, wasTransferred)) { + if (!wasTransferred) { + freeMessage(messagePtr); + } else { + messagePtr->callback = nullptr; + } + } else { + freeMessage(messagePtr); + } +} + +template +void HttpSocket::onEnd(uS::Socket s) { + if (!s.isShuttingDown()) { + if (isServer) { + getGroup(s)->removeHttpSocket(HttpSocket(s)); + getGroup(s)->httpDisconnectionHandler(HttpSocket(s)); + } + } else { + s.cancelTimeout(); + } + + Data *httpSocketData = (Data *) s.getSocketData(); + + s.close(); + + while (!httpSocketData->messageQueue.empty()) { + uS::SocketData::Queue::Message *message = httpSocketData->messageQueue.front(); + if (message->callback) { + message->callback(nullptr, message->callbackData, true, nullptr); + } + httpSocketData->messageQueue.pop(); + } + + while (httpSocketData->outstandingResponsesHead) { + getGroup(s)->httpCancelledRequestHandler(httpSocketData->outstandingResponsesHead); + HttpResponse *next = httpSocketData->outstandingResponsesHead->next; + delete httpSocketData->outstandingResponsesHead; + httpSocketData->outstandingResponsesHead = next; + } + + if (httpSocketData->preAllocatedResponse) { + delete httpSocketData->preAllocatedResponse; + } + + if (!isServer) { + s.cancelTimeout(); + getGroup(s)->errorHandler(httpSocketData->httpUser); + } + + delete httpSocketData; +} + +template struct HttpSocket; +template struct HttpSocket; + +} diff --git a/src/uWS/HTTPSocket.h b/src/uWS/HTTPSocket.h new file mode 100644 index 000000000..cfe1a3070 --- /dev/null +++ b/src/uWS/HTTPSocket.h @@ -0,0 +1,306 @@ +#ifndef HTTPSOCKET_UWS_H +#define HTTPSOCKET_UWS_H + +#include "Socket.h" +#include +// #include + +#include + +namespace uWS { + +struct Header { + char *key, *value; + unsigned int keyLength, valueLength; + + operator bool() { + return key; + } + + // slow without string_view! + std::string toString() { + return std::string(value, valueLength); + } +}; + +enum HttpMethod { + METHOD_GET, + METHOD_POST, + METHOD_PUT, + METHOD_DELETE, + METHOD_PATCH, + METHOD_OPTIONS, + METHOD_HEAD, + METHOD_TRACE, + METHOD_CONNECT, + METHOD_INVALID +}; + +struct HttpRequest { + Header *headers; + Header getHeader(const char *key) { + return getHeader(key, strlen(key)); + } + + HttpRequest(Header *headers = nullptr) : headers(headers) {} + + Header getHeader(const char *key, size_t length) { + if (headers) { + for (Header *h = headers; *++h; ) { + if (h->keyLength == length && !strncmp(h->key, key, length)) { + return *h; + } + } + } + return {nullptr, nullptr, 0, 0}; + } + + Header getUrl() { + if (headers->key) { + return *headers; + } + return {nullptr, nullptr, 0, 0}; + } + + HttpMethod getMethod() { + if (!headers->key) { + return METHOD_INVALID; + } + switch (headers->keyLength) { + case 3: + if (!strncmp(headers->key, "get", 3)) { + return METHOD_GET; + } else if (!strncmp(headers->key, "put", 3)) { + return METHOD_PUT; + } + break; + case 4: + if (!strncmp(headers->key, "post", 4)) { + return METHOD_POST; + } else if (!strncmp(headers->key, "head", 4)) { + return METHOD_HEAD; + } + break; + case 5: + if (!strncmp(headers->key, "patch", 5)) { + return METHOD_PATCH; + } else if (!strncmp(headers->key, "trace", 5)) { + return METHOD_TRACE; + } + break; + case 6: + if (!strncmp(headers->key, "delete", 6)) { + return METHOD_DELETE; + } + break; + case 7: + if (!strncmp(headers->key, "options", 7)) { + return METHOD_OPTIONS; + } else if (!strncmp(headers->key, "connect", 7)) { + return METHOD_CONNECT; + } + break; + } + return METHOD_INVALID; + } +}; + +struct HttpResponse; + +template +struct WIN32_EXPORT HttpSocket : private uS::Socket { + struct Data : uS::SocketData { + std::string httpBuffer; + size_t contentLength = 0; + void *httpUser; + bool missedDeadline = false; + + HttpResponse *outstandingResponsesHead = nullptr; + HttpResponse *outstandingResponsesTail = nullptr; + HttpResponse *preAllocatedResponse = nullptr; + + Data(uS::SocketData *socketData) : uS::SocketData(*socketData) {} + }; + + using uS::Socket::getUserData; + using uS::Socket::setUserData; + using uS::Socket::getAddress; + using uS::Socket::Address; + + uv_poll_t *getPollHandle() const {return p;} + + using uS::Socket::shutdown; + using uS::Socket::close; + + void terminate() { + onEnd(*this); + } + + HttpSocket(uS::Socket s) : uS::Socket(s) {} + + typename HttpSocket::Data *getData() { + return (HttpSocket::Data *) getSocketData(); + } + + void upgrade(const char *secKey, const char *extensions, + size_t extensionsLength, const char *subprotocol, + size_t subprotocolLength, bool *perMessageDeflate); + +private: + friend class uS::Socket; + friend struct HttpResponse; + friend struct Hub; + static void onData(uS::Socket s, char *data, int length); + static void onEnd(uS::Socket s); +}; + +struct HttpResponse { + + HttpSocket httpSocket; + HttpResponse *next = nullptr; + void *userData = nullptr; + void *extraUserData = nullptr; + uS::SocketData::Queue::Message *messageQueue = nullptr; + bool hasEnded = false; + bool hasHead = false; + + HttpResponse(HttpSocket httpSocket) : httpSocket(httpSocket) { + + } + + template + static HttpResponse *allocateResponse(HttpSocket httpSocket, typename HttpSocket::Data *httpData) { + if (httpData->preAllocatedResponse) { + HttpResponse *ret = httpData->preAllocatedResponse; + httpData->preAllocatedResponse = nullptr; + return ret; + } else { + return new HttpResponse(httpSocket); + } + } + + //template + void freeResponse(typename HttpSocket::Data *httpData) { + if (httpData->preAllocatedResponse) { + delete this; + } else { + httpData->preAllocatedResponse = this; + } + } + + void write(const char *message, size_t length = 0, + void(*callback)(void *httpSocket, void *data, bool cancelled, void *reserved) = nullptr, + void *callbackData = nullptr) { + + struct NoopTransformer { + static size_t estimate(const char *data, size_t length) { + return length; + } + + static size_t transform(const char *src, char *dst, size_t length, int transformData) { + memcpy(dst, src, length); + return length; + } + }; + + httpSocket.sendTransformed(message, length, callback, callbackData, 0); + hasHead = true; + } + + // todo: maybe this function should have a fast path for 0 length? + void end(const char *message = nullptr, size_t length = 0, + void(*callback)(void *httpResponse, void *data, bool cancelled, void *reserved) = nullptr, + void *callbackData = nullptr) { + + struct TransformData { + bool hasHead; + } transformData = {hasHead}; + + struct HttpTransformer { + + // todo: this should get TransformData! + static size_t estimate(const char *data, size_t length) { + return length + 128; + } + + static size_t transform(const char *src, char *dst, size_t length, TransformData transformData) { + // todo: sprintf is extremely slow + int offset = transformData.hasHead ? 0 : std::sprintf(dst, "HTTP/1.1 200 OK\r\nContent-Length: %u\r\n\r\n", (unsigned int) length); + memcpy(dst + offset, src, length); + return length + offset; + } + }; + + if (httpSocket.getData()->outstandingResponsesHead != this) { + uS::SocketData::Queue::Message *messagePtr = httpSocket.allocMessage(HttpTransformer::estimate(message, length)); + messagePtr->length = HttpTransformer::transform(message, (char *) messagePtr->data, length, transformData); + messagePtr->callback = callback; + messagePtr->callbackData = callbackData; + messagePtr->nextMessage = messageQueue; + messageQueue = messagePtr; + hasEnded = true; + } else { + httpSocket.sendTransformed(message, length, callback, callbackData, transformData); + // move head as far as possible + HttpResponse *head = next; + while (head) { + // empty message queue + uS::SocketData::Queue::Message *messagePtr = head->messageQueue; + while (messagePtr) { + uS::SocketData::Queue::Message *nextMessage = messagePtr->nextMessage; + + bool wasTransferred; + if (httpSocket.write(messagePtr, wasTransferred)) { + if (!wasTransferred) { + httpSocket.freeMessage(messagePtr); + if (callback) { + callback(this, callbackData, false, nullptr); + } + } else { + messagePtr->callback = callback; + messagePtr->callbackData = callbackData; + } + } else { + httpSocket.freeMessage(messagePtr); + if (callback) { + callback(this, callbackData, true, nullptr); + } + goto updateHead; + } + messagePtr = nextMessage; + } + // cannot go beyond unfinished responses + if (!head->hasEnded) { + break; + } else { + HttpResponse *next = head->next; + head->freeResponse(httpSocket.getData()); + head = next; + } + } + updateHead: + httpSocket.getData()->outstandingResponsesHead = head; + if (!head) { + httpSocket.getData()->outstandingResponsesTail = nullptr; + } + + freeResponse(httpSocket.getData()); + } + } + + void setUserData(void *userData) { + this->userData = userData; + } + + void *getUserData() { + return userData; + } + + HttpSocket getHttpSocket() { + return httpSocket; + } +}; + +} + +#endif // HTTPSOCKET_UWS_H diff --git a/src/uWS/Hub.cpp b/src/uWS/Hub.cpp new file mode 100644 index 000000000..c11b5acac --- /dev/null +++ b/src/uWS/Hub.cpp @@ -0,0 +1,163 @@ +#include "Hub.h" +#include "HTTPSocket.h" +#include + +static const int INFLATE_LESS_THAN_ROUGHLY = 16777216; + +namespace uWS { + +char *Hub::inflate(char *data, size_t &length) { + dynamicInflationBuffer.clear(); + + inflationStream.next_in = (Bytef *) data; + inflationStream.avail_in = length; + + int err; + do { + inflationStream.next_out = (Bytef *) inflationBuffer; + inflationStream.avail_out = LARGE_BUFFER_SIZE; + err = ::inflate(&inflationStream, Z_FINISH); + if (!inflationStream.avail_in) { + break; + } + + dynamicInflationBuffer.append(inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out); + } while (err == Z_BUF_ERROR && dynamicInflationBuffer.length() <= INFLATE_LESS_THAN_ROUGHLY); + + inflateReset(&inflationStream); + + if ((err != Z_BUF_ERROR && err != Z_OK) || dynamicInflationBuffer.length() > INFLATE_LESS_THAN_ROUGHLY) { + length = 0; + return nullptr; + } + + if (dynamicInflationBuffer.length()) { + dynamicInflationBuffer.append(inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out); + + length = dynamicInflationBuffer.length(); + return (char *) dynamicInflationBuffer.data(); + } + + length = LARGE_BUFFER_SIZE - inflationStream.avail_out; + return inflationBuffer; +} + +void Hub::onServerAccept(uS::Socket s) { + uS::SocketData *socketData = s.getSocketData(); + s.enterState>(new HttpSocket::Data(socketData)); + ((Group *) socketData->nodeData)->addHttpSocket(s); + ((Group *) socketData->nodeData)->httpConnectionHandler(s); + s.setNoDelay(true); + delete socketData; +} + +void Hub::onClientConnection(uS::Socket s, bool error) { + HttpSocket::Data *httpSocketData = (HttpSocket::Data *) s.getSocketData(); + + if (error) { + ((Group *) httpSocketData->nodeData)->errorHandler(httpSocketData->httpUser); + delete httpSocketData; + } else { + s.enterState>(s.getSocketData()); + HttpSocket(s).upgrade(nullptr, nullptr, 0, nullptr, 0, nullptr); + } +} + +bool Hub::listen(const char *host, int port, uS::TLS::Context sslContext, int options, Group *eh) { + if (!eh) { + eh = (Group *) this; + } + + if (uS::Node::listen(host, port, sslContext, options, (uS::NodeData *) eh, nullptr)) { + eh->errorHandler(port); + return false; + } + return true; +} + +bool Hub::listen(int port, uS::TLS::Context sslContext, int options, Group *eh) { + return listen(nullptr, port, sslContext, options, eh); +} + +void Hub::connect(std::string uri, void *user, int timeoutMs, Group *eh, std::string subprotocol) { + if (!eh) { + eh = (Group *) this; + } + + size_t offset = 0; + std::string protocol = uri.substr(offset, uri.find("://")), hostname, portStr, path; + if ((offset += protocol.length() + 3) < uri.length()) { + hostname = uri.substr(offset, uri.find_first_of(":/", offset) - offset); + + offset += hostname.length(); + if (uri[offset] == ':') { + offset++; + portStr = uri.substr(offset, uri.find("/", offset) - offset); + } + + offset += portStr.length(); + if (uri[offset] == '/') { + path = uri.substr(++offset); + } + } + + if (hostname.length()) { + int port = 80; + bool secure = false; + if (protocol == "wss") { + port = 443; + secure = true; + } else if (protocol != "ws") { + eh->errorHandler(user); + } + + if (portStr.length()) { + port = stoi(portStr); + } + + uS::SocketData socketData((uS::NodeData *) eh); + HttpSocket::Data *httpSocketData = new HttpSocket::Data(&socketData); + + std::string optionalSubprotocol; + if (!subprotocol.empty()) { + optionalSubprotocol = "Sec-WebSocket-Protocol: " + subprotocol + "\r\n"; + } + httpSocketData->httpUser = user; + httpSocketData->httpBuffer = "GET /" + path + " HTTP/1.1\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" + "Host: " + hostname + "\r\n" + + optionalSubprotocol + + "Sec-WebSocket-Version: 13\r\n\r\n"; + + uS::Socket s = uS::Node::connect(hostname.c_str(), port, secure, httpSocketData); + if (s) { + s.startTimeout::onEnd>(timeoutMs); + // getGroup(s)->addHttpSocket(s); + } + } else { + eh->errorHandler(user); + } +} + +void Hub::upgrade(uv_os_sock_t fd, const char *secKey, SSL *ssl, const char *extensions, size_t extensionsLength, const char *subprotocol, size_t subprotocolLength, Group *serverGroup) { + if (!serverGroup) { + serverGroup = &getDefaultGroup(); + } + + uS::Socket s = uS::Socket::init((uS::NodeData *) serverGroup, fd, ssl); + uS::SocketData *socketData = s.getSocketData(); + HttpSocket::Data *temporaryHttpData = new HttpSocket::Data(socketData); + delete socketData; + s.enterState>(temporaryHttpData); + + bool perMessageDeflate; + HttpSocket(s).upgrade(secKey, extensions, extensionsLength, subprotocol, subprotocolLength, &perMessageDeflate); + s.enterState>(new WebSocket::Data(perMessageDeflate, s.getSocketData())); + serverGroup->addWebSocket(s); + serverGroup->connectionHandler(WebSocket(s), HttpRequest({})); + delete temporaryHttpData; +} + +} diff --git a/src/uWS/Hub.h b/src/uWS/Hub.h new file mode 100644 index 000000000..ba4c03c79 --- /dev/null +++ b/src/uWS/Hub.h @@ -0,0 +1,81 @@ +#ifndef HUB_UWS_H +#define HUB_UWS_H + +#include "Group.h" +#include "Node.h" +#include +#include +#include + +static_assert (UV_VERSION_MINOR >= 3, "µWebSockets requires libuv >=1.3.0"); + +namespace uWS { + +struct WIN32_EXPORT Hub : private uS::Node, public Group, public Group { + + template + Group *createGroup(int extensionOptions = 0) { + return new Group(extensionOptions, this, nodeData); + } + + template + Group &getDefaultGroup() { + return (Group &) *this; + } + + struct ConnectionData { + std::string path; + void *user; + Group *group; + }; + + z_stream inflationStream = {}; + char *inflationBuffer; + char *inflate(char *data, size_t &length); + std::string dynamicInflationBuffer; + static const int LARGE_BUFFER_SIZE = 300 * 1024; + + static void onServerAccept(uS::Socket s); + static void onClientConnection(uS::Socket s, bool error); + + bool listen(int port, uS::TLS::Context sslContext = nullptr, int options = 0, Group *eh = nullptr); + bool listen(const char *host, int port, uS::TLS::Context sslContext = nullptr, int options = 0, Group *eh = nullptr); + void connect(std::string uri, void *user, int timeoutMs = 5000, Group *eh = nullptr, std::string subprotocol = ""); + void upgrade(uv_os_sock_t fd, const char *secKey, SSL *ssl, const char *extensions, size_t extensionsLength, const char *subprotocol, size_t subprotocolLength, Group *serverGroup = nullptr); + + Hub(int extensionOptions = 0, bool useDefaultLoop = false) : uS::Node(LARGE_BUFFER_SIZE, WebSocketProtocol::CONSUME_PRE_PADDING, WebSocketProtocol::CONSUME_POST_PADDING, useDefaultLoop), + Group(extensionOptions, this, nodeData), Group(0, this, nodeData) { + inflateInit2(&inflationStream, -15); + inflationBuffer = new char[LARGE_BUFFER_SIZE]; + } + + ~Hub() { + inflateEnd(&inflationStream); + delete [] inflationBuffer; + } + + using uS::Node::run; + using uS::Node::getLoop; + using Group::onConnection; + using Group::onConnection; + using Group::onMessage; + using Group::onMessage; + using Group::onDisconnection; + using Group::onDisconnection; + using Group::onPing; + using Group::onPing; + using Group::onPong; + using Group::onPong; + using Group::onError; + using Group::onError; + using Group::onHttpRequest; + using Group::onHttpData; + using Group::onHttpConnection; + using Group::onHttpDisconnection; + using Group::onHttpUpgrade; + using Group::onCancelledHttpRequest; +}; + +} + +#endif // HUB_UWS_H diff --git a/src/uWS/Networking.cpp b/src/uWS/Networking.cpp new file mode 100644 index 000000000..743f83b90 --- /dev/null +++ b/src/uWS/Networking.cpp @@ -0,0 +1,78 @@ +#include "Networking.h" + +namespace uS { + +namespace TLS { + +Context::Context(const Context &other) +{ + if (other.context) { + context = other.context; + SSL_CTX_up_ref(context); + } +} + +Context &Context::operator=(const Context &other) { + if (other.context) { + context = other.context; + SSL_CTX_up_ref(context); + } + return *this; +} + +Context::~Context() +{ + if (context) { + SSL_CTX_free(context); + } +} + +struct Init { + Init() {SSL_library_init();} + ~Init() {/*EVP_cleanup();*/} +} init; + +Context createContext(std::string certChainFileName, std::string keyFileName, std::string keyFilePassword) +{ + Context context(SSL_CTX_new(SSLv23_server_method())); + if (!context.context) { + return nullptr; + } + + if (keyFilePassword.length()) { + context.password.reset(new std::string(keyFilePassword)); + SSL_CTX_set_default_passwd_cb_userdata(context.context, context.password.get()); + SSL_CTX_set_default_passwd_cb(context.context, Context::passwordCallback); + } + + SSL_CTX_set_options(context.context, SSL_OP_NO_SSLv3); + + if (SSL_CTX_use_certificate_chain_file(context.context, certChainFileName.c_str()) != 1) { + return nullptr; + } else if (SSL_CTX_use_PrivateKey_file(context.context, keyFileName.c_str(), SSL_FILETYPE_PEM) != 1) { + return nullptr; + } + + return context; +} + +} + +#ifndef _WIN32 +struct Init { + Init() {signal(SIGPIPE, SIG_IGN);} +} init; +#endif + +#ifdef _WIN32 +#pragma comment(lib, "Ws2_32.lib") + +struct WindowsInit { + WSADATA wsaData; + WindowsInit() {WSAStartup(MAKEWORD(2, 2), &wsaData);} + ~WindowsInit() {WSACleanup();} +} windowsInit; + +#endif + +} diff --git a/src/uWS/Networking.h b/src/uWS/Networking.h new file mode 100644 index 000000000..1b7c1ad69 --- /dev/null +++ b/src/uWS/Networking.h @@ -0,0 +1,235 @@ +#ifndef NETWORKING_UWS_H +#define NETWORKING_UWS_H + +#include +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define SSL_CTX_up_ref(x) x->references++ +#define SSL_up_ref(x) x->references++ +#endif + +#ifndef __linux +#define MSG_NOSIGNAL 0 +#else +#include +#endif + +#ifdef __APPLE__ +#define htobe64(x) OSSwapHostToBigInt64(x) +#define be64toh(x) OSSwapBigToHostInt64(x) +#endif + +#ifdef _WIN32 +#define NOMINMAX +#include +#include +#define SHUT_WR SD_SEND +#define htobe64(x) htonll(x) +#define be64toh(x) ntohll(x) +#define __thread __declspec(thread) +#define pthread_t DWORD +#define pthread_self GetCurrentThreadId +#define WIN32_EXPORT __declspec(dllexport) + +inline void close(SOCKET fd) {closesocket(fd);} +inline int setsockopt(SOCKET fd, int level, int optname, const void *optval, socklen_t optlen) { + return setsockopt(fd, level, optname, (const char *) optval, optlen); +} + +inline SOCKET dup(SOCKET socket) { + WSAPROTOCOL_INFOW pi; + if (WSADuplicateSocketW(socket, GetCurrentProcessId(), &pi) == SOCKET_ERROR) { + return INVALID_SOCKET; + } + return WSASocketW(pi.iAddressFamily, pi.iSocketType, pi.iProtocol, &pi, 0, WSA_FLAG_OVERLAPPED); +} +#else +#include +#include +#include +#include +#define SOCKET_ERROR -1 +#define INVALID_SOCKET -1 +#define WIN32_EXPORT +#endif + +#include "uUV.h" +#include +#include +#include +#include +#include +#include + +namespace uS { + +namespace TLS { + +class WIN32_EXPORT Context { +private: + SSL_CTX *context = nullptr; + std::shared_ptr password; + + static int passwordCallback(char *buf, int size, int rwflag, void *u) + { + std::string *password = (std::string *) u; + int length = std::min(size, password->length()); + memcpy(buf, password->data(), length); + buf[length] = '\0'; + return length; + } + +public: + friend Context createContext(std::string certChainFileName, std::string keyFileName, std::string keyFilePassword); + Context(SSL_CTX *context) : context(context) { + + } + + Context() = default; + Context(const Context &other); + Context &operator=(const Context &other); + ~Context(); + operator bool() { + return context; + } + + SSL_CTX *getNativeContext() { + return context; + } +}; + +Context createContext(std::string certChainFileName, std::string keyFileName, std::string keyFilePassword = std::string()); + +} + +struct SocketData; + +struct WIN32_EXPORT NodeData { + char *recvBufferMemoryBlock; + char *recvBuffer; + int recvLength; + uv_loop_t *loop; + void *user = nullptr; + static const int preAllocMaxSize = 1024; + char **preAlloc; + SSL_CTX *clientContext; + + uv_async_t *async = nullptr; + pthread_t tid; + + struct TransferData { + uv_poll_t *p; + uv_os_sock_t fd; + SocketData *socketData; + uv_poll_cb pollCb; + void (*cb)(uv_poll_t *); + }; + + void addAsync() { + async = new uv_async_t; + async->data = this; + uv_async_init(loop, async, NodeData::asyncCallback); + } + + std::mutex *asyncMutex; + std::vector transferQueue; + std::vector changePollQueue; + static void asyncCallback(uv_async_t *async); + + static int getMemoryBlockIndex(size_t length) { + return (length >> 4) + bool(length & 15); + } + + char *getSmallMemoryBlock(int index) { + if (preAlloc[index]) { + char *memory = preAlloc[index]; + preAlloc[index] = nullptr; + return memory; + } else { + return new char[index << 4]; + } + } + + void freeSmallMemoryBlock(char *memory, int index) { + if (!preAlloc[index]) { + preAlloc[index] = memory; + } else { + delete [] memory; + } + } +}; + +struct SocketData { + NodeData *nodeData; + SSL *ssl; + void *user = nullptr; + + // combine these two! state! + int poll; + bool shuttingDown = false; + + SocketData(NodeData *nodeData) : nodeData(nodeData) { + + } + + struct Queue { + struct Message { + const char *data; + size_t length; + Message *nextMessage = nullptr; + void (*callback)(void *socket, void *data, bool cancelled, void *reserved) = nullptr; + void *callbackData = nullptr, *reserved = nullptr; + }; + + Message *head = nullptr, *tail = nullptr; + void pop() + { + Message *nextMessage; + if ((nextMessage = head->nextMessage)) { + delete [] (char *) head; + head = nextMessage; + } else { + delete [] (char *) head; + head = tail = nullptr; + } + } + + bool empty() {return head == nullptr;} + Message *front() {return head;} + + void push(Message *message) + { + message->nextMessage = nullptr; + if (tail) { + tail->nextMessage = message; + tail = message; + } else { + head = message; + tail = message; + } + } + } messageQueue; + + uv_poll_t *next = nullptr, *prev = nullptr; +}; + +struct ListenData : SocketData { + + ListenData(NodeData *nodeData) : SocketData(nodeData) { + + } + + uv_poll_t *listenPoll = nullptr; + uv_timer_t *listenTimer = nullptr; + uv_os_sock_t sock; + uS::TLS::Context sslContext; +}; + +enum SocketState : unsigned char { + CLOSED, + POLL_READ, + POLL_WRITE +}; + +} + +#endif // NETWORKING_UWS_H diff --git a/src/uWS/Node.cpp b/src/uWS/Node.cpp new file mode 100644 index 000000000..f7bf5f3ae --- /dev/null +++ b/src/uWS/Node.cpp @@ -0,0 +1,81 @@ +#include "Node.h" + +namespace uS { + +void NodeData::asyncCallback(uv_async_t *async) +{ + NodeData *nodeData = (NodeData *) async->data; + + nodeData->asyncMutex->lock(); + for (TransferData transferData : nodeData->transferQueue) { + uv_poll_init_socket(nodeData->loop, transferData.p, transferData.fd); + transferData.p->data = transferData.socketData; + transferData.socketData->nodeData = nodeData; + uv_poll_start(transferData.p, transferData.socketData->poll, transferData.pollCb); + + transferData.cb(transferData.p); + } + + for (uv_poll_t *p : nodeData->changePollQueue) { + SocketData *socketData = (SocketData *) p->data; + uv_poll_start(p, socketData->poll, /*p->poll_cb*/ Socket(p).getPollCallback()); + } + + nodeData->changePollQueue.clear(); + nodeData->transferQueue.clear(); + nodeData->asyncMutex->unlock(); +} + +Node::Node(int recvLength, int prePadding, int postPadding, bool useDefaultLoop) { + nodeData = new NodeData; + nodeData->recvBufferMemoryBlock = new char[recvLength]; + nodeData->recvBuffer = nodeData->recvBufferMemoryBlock + prePadding; + nodeData->recvLength = recvLength - prePadding - postPadding; + + nodeData->tid = pthread_self(); + + if (useDefaultLoop) { + loop = uv_default_loop(); + } else { + loop = uv_loop_new(); + } + + nodeData->loop = loop; + nodeData->asyncMutex = &asyncMutex; + + int indices = NodeData::getMemoryBlockIndex(NodeData::preAllocMaxSize) + 1; + nodeData->preAlloc = new char*[indices]; + for (int i = 0; i < indices; i++) { + nodeData->preAlloc[i] = nullptr; + } + + nodeData->clientContext = SSL_CTX_new(SSLv23_client_method()); + SSL_CTX_set_options(nodeData->clientContext, SSL_OP_NO_SSLv3); +} + +void Node::run() { + nodeData->tid = pthread_self(); + + uv_run(loop, UV_RUN_DEFAULT); +} + +Node::~Node() { + delete [] nodeData->recvBufferMemoryBlock; + SSL_CTX_free(nodeData->clientContext); + + int indices = NodeData::getMemoryBlockIndex(NodeData::preAllocMaxSize) + 1; + for (int i = 0; i < indices; i++) { + if (nodeData->preAlloc[i]) { + delete [] nodeData->preAlloc[i]; + } + } + delete [] nodeData->preAlloc; + + delete nodeData; + + if (loop != uv_default_loop()) { + uv_loop_delete(loop); + } +} + +} diff --git a/src/uWS/Node.h b/src/uWS/Node.h new file mode 100644 index 000000000..f9a031e89 --- /dev/null +++ b/src/uWS/Node.h @@ -0,0 +1,237 @@ +#ifndef NODE_UWS_H +#define NODE_UWS_H + +#include "Socket.h" +#include +#include + +namespace uS { + +enum ListenOptions : int { + REUSE_PORT = 1, + ONLY_IPV4 = 2 +}; + +class WIN32_EXPORT Node { +protected: + uv_loop_t *loop; + NodeData *nodeData; + std::mutex asyncMutex; + +public: + Node(int recvLength = 1024, int prePadding = 0, int postPadding = 0, bool useDefaultLoop = false); + ~Node(); + void run(); + + uv_loop_t *getLoop() { + return loop; + } + + template + static void connect_cb(uv_poll_t *p, int status, int events) { + C(p, status < 0); + } + + template + uS::Socket connect(const char *hostname, int port, bool secure, uS::SocketData *socketData) { + uv_poll_t *p = new uv_poll_t; + p->data = socketData; + + addrinfo hints, *result; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + if (getaddrinfo(hostname, std::to_string(port).c_str(), &hints, &result) != 0) { + C(p, true); + delete p; + return nullptr; + } + + uv_os_sock_t fd = socket(result->ai_family, result->ai_socktype, result->ai_protocol); + if (fd == -1) { + C(p, true); + delete p; + return nullptr; + } + +#ifdef __APPLE__ + int noSigpipe = 1; + setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &noSigpipe, sizeof(int)); +#endif + + ::connect(fd, result->ai_addr, result->ai_addrlen); + freeaddrinfo(result); + + NodeData *nodeData = socketData->nodeData; + if (secure) { + socketData->ssl = SSL_new(nodeData->clientContext); + SSL_set_fd(socketData->ssl, fd); + SSL_set_connect_state(socketData->ssl); + SSL_set_mode(socketData->ssl, SSL_MODE_RELEASE_BUFFERS); + SSL_set_tlsext_host_name(socketData->ssl, hostname); + } else { + socketData->ssl = nullptr; + } + + socketData->poll = UV_READABLE; + uv_poll_init_socket(loop, p, fd); + uv_poll_start(p, UV_WRITABLE, connect_cb); + return p; + } + + template + static void accept_poll_cb(uv_poll_t *p, int status, int events) { + ListenData *listenData = (ListenData *) p->data; + accept_cb(listenData); + } + + template + static void accept_timer_cb(uv_timer_t *p) { + ListenData *listenData = (ListenData *) p->data; + accept_cb(listenData); + } + + template + static void accept_cb(ListenData *listenData) { + uv_os_sock_t serverFd = listenData->sock; + uv_os_sock_t clientFd = accept(serverFd, nullptr, nullptr); + if (clientFd == INVALID_SOCKET) { + /* + * If accept is failing, the pending connection won't be removed and the + * polling will cause the server to spin, using 100% cpu. Switch to a timer + * event instead to avoid this. + */ + if (!TIMER && errno != EAGAIN && errno != EWOULDBLOCK) { + uv_poll_stop(listenData->listenPoll); + uv_close(listenData->listenPoll, [](uv_handle_t *handle) { + delete handle; + }); + listenData->listenPoll = nullptr; + + listenData->listenTimer = new uv_timer_t(); + listenData->listenTimer->data = listenData; + uv_timer_init(listenData->nodeData->loop, listenData->listenTimer); + uv_timer_start(listenData->listenTimer, accept_timer_cb, 1000, 1000); + } + return; + } else if (TIMER) { + uv_timer_stop(listenData->listenTimer); + uv_close(listenData->listenTimer, [](uv_handle_t *handle) { + delete handle; + }); + listenData->listenTimer = nullptr; + + listenData->listenPoll = new uv_poll_t; + listenData->listenPoll->data = listenData; + uv_poll_init_socket(listenData->nodeData->loop, listenData->listenPoll, serverFd); + uv_poll_start(listenData->listenPoll, UV_READABLE, accept_poll_cb); + } + do { + #ifdef __APPLE__ + int noSigpipe = 1; + setsockopt(clientFd, SOL_SOCKET, SO_NOSIGPIPE, &noSigpipe, sizeof(int)); + #endif + + SSL *ssl = nullptr; + if (listenData->sslContext) { + ssl = SSL_new(listenData->sslContext.getNativeContext()); + SSL_set_fd(ssl, clientFd); + SSL_set_accept_state(ssl); + SSL_set_mode(ssl, SSL_MODE_RELEASE_BUFFERS); + } + + SocketData *socketData = new SocketData(listenData->nodeData); + socketData->ssl = ssl; + + uv_poll_t *clientPoll = new uv_poll_t; +#ifdef USE_MICRO_UV + uv_poll_init_socket(listenData->listenPoll->get_loop(), clientPoll, clientFd); +#else + uv_poll_init_socket(listenData->listenPoll->loop, clientPoll, clientFd); +#endif + clientPoll->data = socketData; + + socketData->poll = UV_READABLE; + A(clientPoll); + } while ((clientFd = accept(serverFd, nullptr, nullptr)) != INVALID_SOCKET); + } + + // todo: hostname, backlog + template + bool listen(const char *host, int port, uS::TLS::Context sslContext, int options, uS::NodeData *nodeData, void *user) { + addrinfo hints, *result; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_flags = AI_PASSIVE; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + if (getaddrinfo(host, std::to_string(port).c_str(), &hints, &result)) { + return true; + } + + uv_os_sock_t listenFd = SOCKET_ERROR; + addrinfo *listenAddr; + if ((options & uS::ONLY_IPV4) == 0) { + for (addrinfo *a = result; a && listenFd == SOCKET_ERROR; a = a->ai_next) { + if (a->ai_family == AF_INET6) { + listenFd = socket(a->ai_family, a->ai_socktype, a->ai_protocol); + listenAddr = a; + } + } + } + + for (addrinfo *a = result; a && listenFd == SOCKET_ERROR; a = a->ai_next) { + if (a->ai_family == AF_INET) { + listenFd = socket(a->ai_family, a->ai_socktype, a->ai_protocol); + listenAddr = a; + } + } + + if (listenFd == SOCKET_ERROR) { + freeaddrinfo(result); + return true; + } + +#ifdef __linux +#ifdef SO_REUSEPORT + if (options & REUSE_PORT) { + int optval = 1; + setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)); + } +#endif +#endif + + int enabled = true; + setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)); + + if (bind(listenFd, listenAddr->ai_addr, listenAddr->ai_addrlen) || ::listen(listenFd, 512)) { + ::close(listenFd); + freeaddrinfo(result); + return true; + } + + ListenData *listenData = new ListenData(nodeData); + listenData->sslContext = sslContext; + listenData->nodeData = nodeData; + + uv_poll_t *listenPoll = new uv_poll_t; + listenPoll->data = listenData; + + listenData->listenPoll = listenPoll; + listenData->sock = listenFd; + listenData->ssl = nullptr; + + uv_poll_init_socket(loop, listenPoll, listenFd); + uv_poll_start(listenPoll, UV_READABLE, accept_poll_cb); + + // should be vector of listen data! one group can have many listeners! + nodeData->user = listenData; + freeaddrinfo(result); + return false; + } +}; + +} + +#endif // NODE_UWS_H diff --git a/src/uWS/Socket.cpp b/src/uWS/Socket.cpp new file mode 100644 index 000000000..c35bbf8d9 --- /dev/null +++ b/src/uWS/Socket.cpp @@ -0,0 +1,28 @@ +#include "Socket.h" + +namespace uS { + +Socket::Address Socket::getAddress() +{ + uv_os_sock_t fd = getFd(); + + sockaddr_storage addr; + socklen_t addrLength = sizeof(addr); + if (getpeername(fd, (sockaddr *) &addr, &addrLength) == -1) { + return {0, "", ""}; + } + + static __thread char buf[INET6_ADDRSTRLEN]; + + if (addr.ss_family == AF_INET) { + sockaddr_in *ipv4 = (sockaddr_in *) &addr; + inet_ntop(AF_INET, &ipv4->sin_addr, buf, sizeof(buf)); + return {ntohs(ipv4->sin_port), buf, "IPv4"}; + } else { + sockaddr_in6 *ipv6 = (sockaddr_in6 *) &addr; + inet_ntop(AF_INET6, &ipv6->sin6_addr, buf, sizeof(buf)); + return {ntohs(ipv6->sin6_port), buf, "IPv6"}; + } +} + +} diff --git a/src/uWS/Socket.h b/src/uWS/Socket.h new file mode 100644 index 000000000..f05a6b8fc --- /dev/null +++ b/src/uWS/Socket.h @@ -0,0 +1,484 @@ +#ifndef SOCKET_UWS_H +#define SOCKET_UWS_H + +#include "Networking.h" + +namespace uS { + +class WIN32_EXPORT Socket { +protected: + uv_poll_t *p; + +public: + Socket(uv_poll_t *p) : p(p) { + + } + + uv_poll_cb getPollCallback() { +#ifdef USE_MICRO_UV + return p->get_poll_cb(); +#else + return p->poll_cb; +#endif + } + + void transfer(NodeData *nodeData, void (*cb)(uv_poll_t *)) { + SocketData *socketData = getSocketData(); + + nodeData->asyncMutex->lock(); + nodeData->transferQueue.push_back({new uv_poll_t, getFd(), socketData, getPollCallback(), cb}); + nodeData->asyncMutex->unlock(); + + if (socketData->nodeData->tid != nodeData->tid) { + uv_async_send(nodeData->async); + } else { + NodeData::asyncCallback(nodeData->async); + } + + uv_poll_stop(p); + uv_close(p, [](uv_handle_t *h) { + delete (uv_poll_t *) h; + }); + } + + static uv_poll_t *init(NodeData *nodeData, uv_os_sock_t fd, SSL *ssl) { + if (ssl) { + SSL_set_fd(ssl, fd); + SSL_set_mode(ssl, SSL_MODE_RELEASE_BUFFERS); + } + + SocketData *socketData = new SocketData(nodeData); + socketData->ssl = ssl; + socketData->poll = UV_READABLE; + + uv_poll_t *p = new uv_poll_t; + uv_poll_init_socket(nodeData->loop, p, fd); + p->data = socketData; + return p; + } + + uv_os_sock_t getFd() { +#ifdef _WIN32 + uv_os_sock_t fd; + uv_fileno((uv_handle_t *) p, (uv_os_fd_t *) &fd); + return fd; +#else +#ifdef USE_MICRO_UV + return p->fd; +#else + return p->io_watcher.fd; +#endif +#endif + } + + SocketData *getSocketData() { + return (SocketData *) p->data; + } + + NodeData *getNodeData(SocketData *socketData) { + return socketData->nodeData; + } + + void *getUserData() { + return getSocketData()->user; + } + + void setUserData(void *user) { + getSocketData()->user = user; + } + + operator uv_poll_t *() const { + return p; + } + + bool isClosed() { + return uv_is_closing(p); + } + + bool isShuttingDown() { + return getSocketData()->shuttingDown; + } + + struct Address { + unsigned int port; + const char *address; + const char *family; + }; + + Address getAddress(); + + void cork(int enable) { +#if defined(TCP_CORK) + // Linux & SmartOS have proper TCP_CORK + setsockopt(getFd(), IPPROTO_TCP, TCP_CORK, &enable, sizeof(int)); +#elif defined(TCP_NOPUSH) + // Mac OS X & FreeBSD have TCP_NOPUSH + setsockopt(getFd(), IPPROTO_TCP, TCP_NOPUSH, &enable, sizeof(int)); + if (!enable) { + // Tested on OS X, FreeBSD situation is unclear + ::send(getFd(), "", 0, MSG_NOSIGNAL); + } +#endif + } + + void setNoDelay(int enable) { + setsockopt(getFd(), IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); + } + + void shutdown() { + SSL *ssl = getSocketData()->ssl; + if (ssl) { + //todo: poll in/out - have the io_cb recall shutdown if failed + SSL_shutdown(ssl); + } else { + ::shutdown(getFd(), SHUT_WR); + } + } + + // clears user data! + template + void startTimeout(int timeoutMs = 15000) { + SocketData *socketData = getSocketData(); + NodeData *nodeData = getNodeData(socketData); + + uv_timer_t *timer = new uv_timer_t; + timer->data = p; + uv_timer_init(nodeData->loop, timer); + uv_timer_start(timer, [](uv_timer_t *timer) { + Socket s((uv_poll_t *) timer->data); + s.cancelTimeout(); + onTimeout(s); + }, timeoutMs, 0); + + socketData->user = timer; + } + + void cancelTimeout() { + uv_timer_t *timer = (uv_timer_t *) getUserData(); + if (timer) { + uv_timer_stop(timer); + uv_close(timer, [](uv_handle_t *handle) { + delete (uv_timer_t *) handle; + }); + getSocketData()->user = nullptr; + } + } + + template + static void ssl_io_cb(uv_poll_t *p, int status, int events) { + SocketData *socketData = Socket(p).getSocketData(); + NodeData *nodeData = socketData->nodeData; + SSL *ssl = socketData->ssl; + + if (status < 0) { + STATE::onEnd(p); + return; + } + + if (!socketData->messageQueue.empty() && ((events & UV_WRITABLE) || SSL_want(socketData->ssl) == SSL_READING)) { + Socket(p).cork(true); + while (true) { + SocketData::Queue::Message *messagePtr = socketData->messageQueue.front(); + int sent = SSL_write(socketData->ssl, messagePtr->data, messagePtr->length); + if (sent == (ssize_t) messagePtr->length) { + if (messagePtr->callback) { + messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved); + } + socketData->messageQueue.pop(); + if (socketData->messageQueue.empty()) { + if ((socketData->poll & UV_WRITABLE) && SSL_want(socketData->ssl) != SSL_WRITING) { + // todo, remove bit, don't set directly + socketData->poll = UV_READABLE; + uv_poll_start(p, UV_READABLE, Socket(p).getPollCallback()); + } + break; + } + } else if (sent <= 0) { + switch (SSL_get_error(socketData->ssl, sent)) { + case SSL_ERROR_WANT_READ: + break; + case SSL_ERROR_WANT_WRITE: + if ((socketData->poll & UV_WRITABLE) == 0) { + socketData->poll |= UV_WRITABLE; + uv_poll_start(p, socketData->poll, Socket(p).getPollCallback()); + } + break; + default: + STATE::onEnd(p); + return; + } + break; + } + } + Socket(p).cork(false); + } + + if (events & UV_READABLE) { + do { + int length = SSL_read(ssl, nodeData->recvBuffer, nodeData->recvLength); + if (length <= 0) { + switch (SSL_get_error(ssl, length)) { + case SSL_ERROR_WANT_READ: + break; + case SSL_ERROR_WANT_WRITE: + if ((socketData->poll & UV_WRITABLE) == 0) { + socketData->poll |= UV_WRITABLE; + uv_poll_start(p, socketData->poll, Socket(p).getPollCallback()); + } + break; + default: + STATE::onEnd(p); + return; + } + break; + } else { + STATE::onData(p, nodeData->recvBuffer, length); + if (Socket(p).isClosed() || Socket(p).isShuttingDown()) { + return; + } + } + } while (SSL_pending(ssl)); + } + } + + template + static void io_cb(uv_poll_t *p, int status, int events) { + SocketData *socketData = Socket(p).getSocketData(); + NodeData *nodeData = socketData->nodeData; + + if (status < 0) { + STATE::onEnd(p); + return; + } + + if (events & UV_WRITABLE) { + if (!socketData->messageQueue.empty() && (events & UV_WRITABLE)) { + Socket(p).cork(true); + while (true) { + SocketData::Queue::Message *messagePtr = socketData->messageQueue.front(); + ssize_t sent = ::send(Socket(p).getFd(), messagePtr->data, messagePtr->length, MSG_NOSIGNAL); + if (sent == (ssize_t) messagePtr->length) { + if (messagePtr->callback) { + messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved); + } + socketData->messageQueue.pop(); + if (socketData->messageQueue.empty()) { + // todo, remove bit, don't set directly + socketData->poll = UV_READABLE; + uv_poll_start(p, UV_READABLE, Socket(p).getPollCallback()); + break; + } + } else if (sent == SOCKET_ERROR) { + if (errno != EWOULDBLOCK) { + STATE::onEnd(p); + return; + } + break; + } else { + messagePtr->length -= sent; + messagePtr->data += sent; + break; + } + } + Socket(p).cork(false); + } + } + + if (events & UV_READABLE) { + int length = recv(Socket(p).getFd(), nodeData->recvBuffer, nodeData->recvLength, 0); + if (length > 0) { + STATE::onData(p, nodeData->recvBuffer, length); + } else if (length <= 0 || (length == SOCKET_ERROR && errno != EWOULDBLOCK)) { + STATE::onEnd(p); + } + } + + } + + template + void enterState(void *socketData) { + p->data = socketData; + if (Socket(p).getSocketData()->ssl) { + uv_poll_start(p, Socket(p).getSocketData()->poll, ssl_io_cb); + } else { + uv_poll_start(p, Socket(p).getSocketData()->poll, io_cb); + } + Socket(p).getSocketData()->poll = UV_READABLE; + } + + /*void setPoll(int poll = UV_READABLE) { + Socket(p).getSocketData()->poll = poll; + }*/ + + // does not change STATE, only poll for current state + /*void addPollBit(int events) { + uv_poll_start(p, p->flags, p->poll_cb); + } + + void removePollBit(int events) { + uv_poll_start(p, p->flags, p->poll_cb); + }*/ + + void close() { + uv_os_sock_t fd = getFd(); + uv_poll_stop(p); + ::close(fd); + + SSL *ssl = getSocketData()->ssl; + if (ssl) { + SSL_free(ssl); + } + + uv_close(p, [](uv_handle_t *h) { + delete (uv_poll_t *) h; + }); + } + + bool hasEmptyQueue() { + return getSocketData()->messageQueue.empty(); + } + + void enqueue(SocketData::Queue::Message *message) { + getSocketData()->messageQueue.push(message); + } + + SocketData::Queue::Message *allocMessage(size_t length, const char *data = 0) { + SocketData::Queue::Message *messagePtr = (SocketData::Queue::Message *) new char[sizeof(SocketData::Queue::Message) + length]; + messagePtr->length = length; + messagePtr->data = ((char *) messagePtr) + sizeof(SocketData::Queue::Message); + messagePtr->nextMessage = nullptr; + + if (data) { + memcpy((char *) messagePtr->data, data, messagePtr->length); + } + + return messagePtr; + } + + void freeMessage(SocketData::Queue::Message *message) { + delete [] (char *) message; + } + + void changePoll(SocketData *socketData) { + if (socketData->nodeData->tid != pthread_self()) { + socketData->nodeData->asyncMutex->lock(); + socketData->nodeData->changePollQueue.push_back(p); + socketData->nodeData->asyncMutex->unlock(); + uv_async_send(socketData->nodeData->async); + } else { + uv_poll_start(p, socketData->poll, getPollCallback()); + } + } + + bool write(SocketData::Queue::Message *message, bool &wasTransferred) { + ssize_t sent = 0; + SocketData *socketData = getSocketData(); + if (socketData->messageQueue.empty()) { + + if (socketData->ssl) { + sent = SSL_write(socketData->ssl, message->data, message->length); + if (sent == (ssize_t) message->length) { + wasTransferred = false; + return true; + } else if (sent < 0) { + switch (SSL_get_error(socketData->ssl, sent)) { + case SSL_ERROR_WANT_READ: + break; + case SSL_ERROR_WANT_WRITE: + if ((socketData->poll & UV_WRITABLE) == 0) { + socketData->poll |= UV_WRITABLE; + changePoll(socketData); + } + break; + default: + return false; + } + } + } else { + sent = ::send(getFd(), message->data, message->length, MSG_NOSIGNAL); + if (sent == (ssize_t) message->length) { + wasTransferred = false; + return true; + } else if (sent == SOCKET_ERROR) { + if (errno != EWOULDBLOCK) { + return false; + } + } else { + message->length -= sent; + message->data += sent; + } + + if ((socketData->poll & UV_WRITABLE) == 0) { + socketData->poll |= UV_WRITABLE; + changePoll(socketData); + } + } + } + socketData->messageQueue.push(message); + wasTransferred = true; + return true; + } + + template + void sendTransformed(const char *message, size_t length, void(*callback)(void *httpSocket, void *data, bool cancelled, void *reserved), void *callbackData, D transformData) { + size_t estimatedLength = T::estimate(message, length) + sizeof(uS::SocketData::Queue::Message); + + if (hasEmptyQueue()) { + if (estimatedLength <= uS::NodeData::preAllocMaxSize) { + int memoryLength = estimatedLength; + int memoryIndex = getSocketData()->nodeData->getMemoryBlockIndex(memoryLength); + + uS::SocketData::Queue::Message *messagePtr = (uS::SocketData::Queue::Message *) getSocketData()->nodeData->getSmallMemoryBlock(memoryIndex); + messagePtr->data = ((char *) messagePtr) + sizeof(uS::SocketData::Queue::Message); + messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); + + bool wasTransferred; + if (write(messagePtr, wasTransferred)) { + if (!wasTransferred) { + getSocketData()->nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex); + if (callback) { + callback(*this, callbackData, false, nullptr); + } + } else { + messagePtr->callback = callback; + messagePtr->callbackData = callbackData; + } + } else { + if (callback) { + callback(*this, callbackData, true, nullptr); + } + } + } else { + uS::SocketData::Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(uS::SocketData::Queue::Message)); + messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); + + bool wasTransferred; + if (write(messagePtr, wasTransferred)) { + if (!wasTransferred) { + freeMessage(messagePtr); + if (callback) { + callback(*this, callbackData, false, nullptr); + } + } else { + messagePtr->callback = callback; + messagePtr->callbackData = callbackData; + } + } else { + if (callback) { + callback(*this, callbackData, true, nullptr); + } + } + } + } else { + uS::SocketData::Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(uS::SocketData::Queue::Message)); + messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); + messagePtr->callback = callback; + messagePtr->callbackData = callbackData; + enqueue(messagePtr); + } + } +}; + +} + +#endif // SOCKET_UWS_H diff --git a/src/uWS/WebSocket.cpp b/src/uWS/WebSocket.cpp new file mode 100644 index 000000000..381ed2fa0 --- /dev/null +++ b/src/uWS/WebSocket.cpp @@ -0,0 +1,178 @@ +#include "WebSocket.h" +#include "Group.h" + +namespace uWS { + +template +void WebSocket::send(const char *message, size_t length, OpCode opCode, void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved), void *callbackData) { + const int HEADER_LENGTH = WebSocketProtocol::LONG_MESSAGE_HEADER; + + struct TransformData { + OpCode opCode; + } transformData = {opCode}; + + struct WebSocketTransformer { + static size_t estimate(const char *data, size_t length) { + return length + HEADER_LENGTH; + } + + static size_t transform(const char *src, char *dst, size_t length, TransformData transformData) { + return WebSocketProtocol::formatMessage(dst, src, length, transformData.opCode, length, false); + } + }; + + sendTransformed((char *) message, length, callback, callbackData, transformData); +} + +template +typename WebSocket::PreparedMessage *WebSocket::prepareMessage(char *data, size_t length, OpCode opCode, bool compressed, void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved)) { + PreparedMessage *preparedMessage = new PreparedMessage; + preparedMessage->buffer = new char[length + 10]; + preparedMessage->length = WebSocketProtocol::formatMessage(preparedMessage->buffer, data, length, opCode, length, compressed); + preparedMessage->references = 1; + preparedMessage->callback = callback; + return preparedMessage; +} + +template +typename WebSocket::PreparedMessage *WebSocket::prepareMessageBatch(std::vector &messages, std::vector &excludedMessages, OpCode opCode, bool compressed, void (*callback)(void *, void *, bool, void *)) +{ + // should be sent in! + size_t batchLength = 0; + for (size_t i = 0; i < messages.size(); i++) { + batchLength += messages[i].length(); + } + + PreparedMessage *preparedMessage = new PreparedMessage; + preparedMessage->buffer = new char[batchLength + 10 * messages.size()]; + + int offset = 0; + for (size_t i = 0; i < messages.size(); i++) { + offset += WebSocketProtocol::formatMessage(preparedMessage->buffer + offset, messages[i].data(), messages[i].length(), opCode, messages[i].length(), compressed); + } + preparedMessage->length = offset; + preparedMessage->references = 1; + preparedMessage->callback = callback; + return preparedMessage; +} + +// todo: see if this can be made a transformer instead +template +void WebSocket::sendPrepared(typename WebSocket::PreparedMessage *preparedMessage, void *callbackData) { + preparedMessage->references++; + void (*callback)(void *webSocket, void *userData, bool cancelled, void *reserved) = [](void *webSocket, void *userData, bool cancelled, void *reserved) { + PreparedMessage *preparedMessage = (PreparedMessage *) userData; + bool lastReference = !--preparedMessage->references; + + if (preparedMessage->callback) { + preparedMessage->callback(webSocket, reserved, cancelled, (void *) lastReference); + } + + if (lastReference) { + delete [] preparedMessage->buffer; + delete preparedMessage; + } + }; + + // candidate for fixed size pool allocator + int memoryLength = sizeof(uS::SocketData::Queue::Message); + int memoryIndex = getSocketData()->nodeData->getMemoryBlockIndex(memoryLength); + + uS::SocketData::Queue::Message *messagePtr = (uS::SocketData::Queue::Message *) getSocketData()->nodeData->getSmallMemoryBlock(memoryIndex); + messagePtr->data = preparedMessage->buffer; + messagePtr->length = preparedMessage->length; + + bool wasTransferred; + if (write(messagePtr, wasTransferred)) { + if (!wasTransferred) { + getSocketData()->nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex); + if (callback) { + callback(*this, preparedMessage, false, callbackData); + } + } else { + messagePtr->callback = callback; + messagePtr->callbackData = preparedMessage; + messagePtr->reserved = callbackData; + } + } else { + if (callback) { + callback(*this, preparedMessage, true, callbackData); + } + } +} + +template +void WebSocket::finalizeMessage(typename WebSocket::PreparedMessage *preparedMessage) { + if (!--preparedMessage->references) { + delete [] preparedMessage->buffer; + delete preparedMessage; + } +} + +template +void WebSocket::onData(uS::Socket s, char *data, int length) { + Data *webSocketData = (Data *) s.getSocketData(); + webSocketData->hasOutstandingPong = false; + if (!s.isShuttingDown()) { + s.cork(true); + ((WebSocketProtocol *) webSocketData)->consume(data, length, s); + if (!s.isClosed()) { + s.cork(false); + } + } +} + +template +void WebSocket::terminate() { + WebSocket::onEnd(*this); +} + +template +void WebSocket::close(int code, const char *message, size_t length) { + static const int MAX_CLOSE_PAYLOAD = 123; + length = std::min(MAX_CLOSE_PAYLOAD, length); + getGroup(*this)->removeWebSocket(*this); + getGroup(*this)->disconnectionHandler(*this, code, (char *) message, length); + getSocketData()->shuttingDown = true; + + // todo: using the shared timer in the group, we can skip creating a new timer per socket + // only this line and the one in Hub::connect uses the timeout feature + startTimeout::onEnd>(); + + char closePayload[MAX_CLOSE_PAYLOAD + 2]; + int closePayloadLength = WebSocketProtocol::formatClosePayload(closePayload, code, message, length); + send(closePayload, closePayloadLength, OpCode::CLOSE, [](void *p, void *data, bool cancelled, void *reserved) { + if (!cancelled) { + Socket((uv_poll_t *) p).shutdown(); + } + }); +} + +template +void WebSocket::onEnd(uS::Socket s) { + if (!s.isShuttingDown()) { + getGroup(s)->removeWebSocket(s); + getGroup(s)->disconnectionHandler(WebSocket(s), 1006, nullptr, 0); + } else { + s.cancelTimeout(); + } + + Data *webSocketData = (Data *) s.getSocketData(); + + s.close(); + + while (!webSocketData->messageQueue.empty()) { + uS::SocketData::Queue::Message *message = webSocketData->messageQueue.front(); + if (message->callback) { + message->callback(nullptr, message->callbackData, true, nullptr); + } + webSocketData->messageQueue.pop(); + } + + delete webSocketData; +} + +template struct WebSocket; +template struct WebSocket; + +} diff --git a/src/uWS/WebSocket.h b/src/uWS/WebSocket.h new file mode 100644 index 000000000..eacb6550f --- /dev/null +++ b/src/uWS/WebSocket.h @@ -0,0 +1,86 @@ +#ifndef WEBSOCKET_UWS_H +#define WEBSOCKET_UWS_H + +#include "WebSocketProtocol.h" +#include "Socket.h" + +namespace uWS { + +template +struct Group; + +template +struct WIN32_EXPORT WebSocket : protected uS::Socket { + struct Data : uS::SocketData, WebSocketProtocol { + std::string fragmentBuffer, controlBuffer; + enum CompressionStatus : char { + DISABLED, + ENABLED, + COMPRESSED_FRAME + } compressionStatus; + bool hasOutstandingPong = false; + + Data(bool perMessageDeflate, uS::SocketData *socketData) : uS::SocketData(*socketData) { + compressionStatus = perMessageDeflate ? CompressionStatus::ENABLED : CompressionStatus::DISABLED; + } + }; + + WebSocket(uS::Socket s = nullptr) : uS::Socket(s) { + + } + + struct PreparedMessage { + char *buffer; + size_t length; + int references; + void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved); + }; + + using uS::Socket::getUserData; + using uS::Socket::setUserData; + using uS::Socket::getAddress; + using uS::Socket::Address; + + void transfer(Group *group) { + ((Group *) getSocketData()->nodeData)->removeWebSocket(p); + uS::Socket::transfer((uS::NodeData *) group, [](uv_poll_t *p) { + uS::Socket s(p); + ((Group *) s.getSocketData()->nodeData)->addWebSocket(s); + }); + } + + uv_poll_t *getPollHandle() const {return p;} + void terminate(); + void close(int code = 1000, const char *message = nullptr, size_t length = 0); + void ping(const char *message) {send(message, OpCode::PING);} + void send(const char *message, OpCode opCode = OpCode::TEXT) {send(message, strlen(message), opCode);} + void send(const char *message, size_t length, OpCode opCode, void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved) = nullptr, void *callbackData = nullptr); + static PreparedMessage *prepareMessage(char *data, size_t length, OpCode opCode, bool compressed, void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved) = nullptr); + static PreparedMessage *prepareMessageBatch(std::vector &messages, std::vector &excludedMessages, OpCode opCode, bool compressed, void(*callback)(void *webSocket, void *data, bool cancelled, void *reserved) = nullptr); + void sendPrepared(PreparedMessage *preparedMessage, void *callbackData = nullptr); + static void finalizeMessage(PreparedMessage *preparedMessage); + bool operator==(const WebSocket &other) const {return p == other.p;} + bool operator<(const WebSocket &other) const {return p < other.p;} + +private: + friend class uS::Socket; + template friend struct Group; + static void onData(uS::Socket s, char *data, int length); + static void onEnd(uS::Socket s); +}; + +} + +namespace std { + +template +struct hash> { + std::size_t operator()(const uWS::WebSocket &webSocket) const + { + return std::hash()(webSocket.getPollHandle()); + } +}; + +} + +#endif // WEBSOCKET_UWS_H diff --git a/src/uWS/WebSocketImpl.cpp b/src/uWS/WebSocketImpl.cpp new file mode 100644 index 000000000..e9d91ca98 --- /dev/null +++ b/src/uWS/WebSocketImpl.cpp @@ -0,0 +1,115 @@ +#include "Hub.h" + +namespace uWS { + +template +bool WebSocketProtocol::setCompressed(void *user) { + uS::Socket s((uv_poll_t *) user); + typename WebSocket::Data *webSocketData = (typename WebSocket::Data *) s.getSocketData(); + + if (webSocketData->compressionStatus == WebSocket::Data::CompressionStatus::ENABLED) { + webSocketData->compressionStatus = WebSocket::Data::CompressionStatus::COMPRESSED_FRAME; + return true; + } else { + return false; + } +} + +template +bool WebSocketProtocol::refusePayloadLength(void *user, int length) { + return length > 16777216; +} + +template +void WebSocketProtocol::forceClose(void *user) { + WebSocket((uv_poll_t *) user).terminate(); +} + +template +bool WebSocketProtocol::handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, void *user) { + uS::Socket s((uv_poll_t *) user); + typename WebSocket::Data *webSocketData = (typename WebSocket::Data *) s.getSocketData(); + + if (opCode < 3) { + if (!remainingBytes && fin && !webSocketData->fragmentBuffer.length()) { + if (webSocketData->compressionStatus == WebSocket::Data::CompressionStatus::COMPRESSED_FRAME) { + webSocketData->compressionStatus = WebSocket::Data::CompressionStatus::ENABLED; + Hub *hub = ((Group *) s.getSocketData()->nodeData)->hub; + data = hub->inflate(data, length); + if (!data) { + forceClose(user); + return true; + } + } + + if (opCode == 1 && !isValidUtf8((unsigned char *) data, length)) { + forceClose(user); + return true; + } + + ((Group *) s.getSocketData()->nodeData)->messageHandler(WebSocket(s), data, length, (OpCode) opCode); + if (s.isClosed() || s.isShuttingDown()) { + return true; + } + } else { + webSocketData->fragmentBuffer.append(data, length); + if (!remainingBytes && fin) { + length = webSocketData->fragmentBuffer.length(); + if (webSocketData->compressionStatus == WebSocket::Data::CompressionStatus::COMPRESSED_FRAME) { + webSocketData->compressionStatus = WebSocket::Data::CompressionStatus::ENABLED; + Hub *hub = ((Group *) s.getSocketData()->nodeData)->hub; + webSocketData->fragmentBuffer.append("...."); + data = hub->inflate((char *) webSocketData->fragmentBuffer.data(), length); + if (!data) { + forceClose(user); + return true; + } + } else { + data = (char *) webSocketData->fragmentBuffer.data(); + } + + if (opCode == 1 && !isValidUtf8((unsigned char *) data, length)) { + forceClose(user); + return true; + } + + ((Group *) s.getSocketData()->nodeData)->messageHandler(WebSocket(s), data, length, (OpCode) opCode); + if (s.isClosed() || s.isShuttingDown()) { + return true; + } + webSocketData->fragmentBuffer.clear(); + } + } + } else { + // todo: we don't need to buffer up in most cases! + webSocketData->controlBuffer.append(data, length); + if (!remainingBytes && fin) { + if (opCode == CLOSE) { + CloseFrame closeFrame = parseClosePayload((char *) webSocketData->controlBuffer.data(), webSocketData->controlBuffer.length()); + WebSocket(s).close(closeFrame.code, closeFrame.message, closeFrame.length); + return true; + } else { + if (opCode == PING) { + WebSocket(s).send(webSocketData->controlBuffer.data(), webSocketData->controlBuffer.length(), (OpCode) OpCode::PONG); + ((Group *) s.getSocketData()->nodeData)->pingHandler(WebSocket(s), (char *) webSocketData->controlBuffer.data(), webSocketData->controlBuffer.length()); + if (s.isClosed() || s.isShuttingDown()) { + return true; + } + } else if (opCode == PONG) { + ((Group *) s.getSocketData()->nodeData)->pongHandler(WebSocket(s), (char *) webSocketData->controlBuffer.data(), webSocketData->controlBuffer.length()); + if (s.isClosed() || s.isShuttingDown()) { + return true; + } + } + } + webSocketData->controlBuffer.clear(); + } + } + + return false; +} + +template class WebSocketProtocol; +template class WebSocketProtocol; + +} diff --git a/src/uWS/WebSocketProtocol.h b/src/uWS/WebSocketProtocol.h new file mode 100644 index 000000000..daefcc11d --- /dev/null +++ b/src/uWS/WebSocketProtocol.h @@ -0,0 +1,372 @@ +#ifndef WEBSOCKETPROTOCOL_UWS_H +#define WEBSOCKETPROTOCOL_UWS_H + +// we do need to include this for htobe64, should be moved from networking! +#include "Networking.h" + +#include +#include + +namespace uWS { + +enum OpCode : unsigned char { + TEXT = 1, + BINARY = 2, + CLOSE = 8, + PING = 9, + PONG = 10 +}; + +enum { + CLIENT, + SERVER +}; + +template +class WebSocketProtocol { +public: + static const int SHORT_MESSAGE_HEADER = isServer ? 6 : 2; + static const int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4; + static const int LONG_MESSAGE_HEADER = isServer ? 14 : 10; + +private: + typedef uint16_t frameFormat; + static inline bool isFin(frameFormat &frame) {return frame & 128;} + static inline unsigned char getOpCode(frameFormat &frame) {return frame & 15;} + static inline unsigned char payloadLength(frameFormat &frame) {return (frame >> 8) & 127;} + static inline bool rsv23(frameFormat &frame) {return frame & 48;} + static inline bool rsv1(frameFormat &frame) {return frame & 64;} + static inline bool getMask(frameFormat &frame) {return frame & 32768;} + + static inline void unmaskImprecise(char *dst, char *src, char *mask, unsigned int length) + { + for (unsigned int n = (length >> 2) + 1; n; n--) { + *(dst++) = *(src++) ^ mask[0]; + *(dst++) = *(src++) ^ mask[1]; + *(dst++) = *(src++) ^ mask[2]; + *(dst++) = *(src++) ^ mask[3]; + } + } + + static inline void unmaskImpreciseCopyMask(char *dst, char *src, char *maskPtr, unsigned int length) + { + char mask[4] = {maskPtr[0], maskPtr[1], maskPtr[2], maskPtr[3]}; + unmaskImprecise(dst, src, mask, length); + } + + static inline void rotateMask(unsigned int offset, char *mask) + { + char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]}; + mask[(0 + offset) % 4] = originalMask[0]; + mask[(1 + offset) % 4] = originalMask[1]; + mask[(2 + offset) % 4] = originalMask[2]; + mask[(3 + offset) % 4] = originalMask[3]; + } + + static inline void unmaskInplace(char *data, char *stop, char *mask) + { + while (data < stop) { + *(data++) ^= mask[0]; + *(data++) ^= mask[1]; + *(data++) ^= mask[2]; + *(data++) ^= mask[3]; + } + } + + enum state_t { + READ_HEAD, + READ_MESSAGE + }; + + enum send_state_t { + SND_CONTINUATION = 1, + SND_NO_FIN = 2, + SND_COMPRESSED = 64 + }; + + template + inline bool consumeMessage(T payLength, char *&src, unsigned int &length, frameFormat frame, void *user) { + if (getOpCode(frame)) { + if (opStack == 1 || (!lastFin && getOpCode(frame) < 2)) { + forceClose(user); + return true; + } + opCode[(unsigned char) ++opStack] = (OpCode) getOpCode(frame); + } else if (opStack == -1) { + forceClose(user); + return true; + } + lastFin = isFin(frame); + + if (refusePayloadLength(user, payLength)) { + forceClose(user); + return true; + } + + if (int(payLength) <= int(length - MESSAGE_HEADER)) { + if (isServer) { + unmaskImpreciseCopyMask(src, src + MESSAGE_HEADER, src + MESSAGE_HEADER - 4, payLength); + if (handleFragment(src, payLength, 0, opCode[(unsigned char) opStack], isFin(frame), user)) { + return true; + } + } else { + if (handleFragment(src + MESSAGE_HEADER, payLength, 0, opCode[(unsigned char) opStack], isFin(frame), user)) { + return true; + } + } + + if (isFin(frame)) { + opStack--; + } + + src += payLength + MESSAGE_HEADER; + length -= payLength + MESSAGE_HEADER; + spillLength = 0; + return false; + } else { + spillLength = 0; + state = READ_MESSAGE; + remainingBytes = payLength - length + MESSAGE_HEADER; + + if (isServer) { + memcpy(mask, src + MESSAGE_HEADER - 4, 4); + unmaskImprecise(src, src + MESSAGE_HEADER, mask, length); + rotateMask(4 - (length - MESSAGE_HEADER) % 4, mask); + } else { + src += MESSAGE_HEADER; + } + handleFragment(src, length - MESSAGE_HEADER, remainingBytes, opCode[(unsigned char) opStack], isFin(frame), user); + return true; + } + } + + inline bool consumeContinuation(char *&src, unsigned int &length, void *user) { + if (remainingBytes <= length) { + if (isServer) { + int n = remainingBytes >> 2; + unmaskInplace(src, src + n * 4, mask); + for (int i = 0, s = remainingBytes % 4; i < s; i++) { + src[n * 4 + i] ^= mask[i]; + } + } + + if (handleFragment(src, remainingBytes, 0, opCode[(unsigned char) opStack], lastFin, user)) { + return false; + } + + if (lastFin) { + opStack--; + } + + src += remainingBytes; + length -= remainingBytes; + state = READ_HEAD; + return true; + } else { + if (isServer) { + unmaskInplace(src, src + ((length >> 2) + 1) * 4, mask); + } + + remainingBytes -= length; + if (handleFragment(src, length, remainingBytes, opCode[(unsigned char) opStack], lastFin, user)) { + return false; + } + + if (isServer && length % 4) { + rotateMask(4 - (length % 4), mask); + } + return false; + } + } + + // this can hold two states (1 bit) + // this can hold length of spill (up to 16 = 4 bit) + unsigned char state = READ_HEAD; + unsigned char spillLength = 0; // remove this! + char opStack = -1; // remove this too + char lastFin = true; // hold in state! + unsigned char spill[LONG_MESSAGE_HEADER - 1]; + unsigned int remainingBytes = 0; // denna kan hålla spillLength om state är READ_HEAD, och remainingBytes när state är annat? + char mask[isServer ? 4 : 1]; + OpCode opCode[2]; + +public: + WebSocketProtocol() { + + } + + // Based on utf8_check.c by Markus Kuhn, 2005 + // https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c + // Optimized for predominantly 7-bit content, 2016 + static bool isValidUtf8(unsigned char *s, size_t length) + { + for (unsigned char *e = s + length; s != e; ) { + if (s + 4 <= e && ((*(uint32_t *) s) & 0x80808080) == 0) { + s += 4; + } else { + while (!(*s & 0x80)) { + if (++s == e) { + return true; + } + } + + if ((s[0] & 0x60) == 0x40) { + if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) { + return false; + } + s += 2; + } else if ((s[0] & 0xf0) == 0xe0) { + if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || + (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) { + return false; + } + s += 3; + } else if ((s[0] & 0xf8) == 0xf0) { + if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 || + (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) { + return false; + } + s += 4; + } else { + return false; + } + } + } + return true; + } + + struct CloseFrame { + uint16_t code; + char *message; + size_t length; + }; + + static inline CloseFrame parseClosePayload(char *src, size_t length) { + CloseFrame cf = {}; + if (length >= 2) { + memcpy(&cf.code, src, 2); + cf = {ntohs(cf.code), src + 2, length - 2}; + if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) || + (cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) { + return {}; + } + } + return cf; + } + + static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) { + if (code) { + code = htons(code); + memcpy(dst, &code, 2); + memcpy(dst + 2, message, length); + return length + 2; + } + return 0; + } + + static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed) { + size_t messageLength; + size_t headerLength; + if (reportedLength < 126) { + headerLength = 2; + dst[1] = reportedLength; + } else if (reportedLength <= UINT16_MAX) { + headerLength = 4; + dst[1] = 126; + *((uint16_t *) &dst[2]) = htons(reportedLength); + } else { + headerLength = 10; + dst[1] = 127; + *((uint64_t *) &dst[2]) = htobe64(reportedLength); + } + + int flags = 0; + dst[0] = (flags & SND_NO_FIN ? 0 : 128) | (compressed ? SND_COMPRESSED : 0); + if (!(flags & SND_CONTINUATION)) { + dst[0] |= opCode; + } + + char mask[4]; + if (!isServer) { + dst[1] |= 0x80; + uint32_t random = rand(); + memcpy(mask, &random, 4); + memcpy(dst + headerLength, &random, 4); + headerLength += 4; + } + + messageLength = headerLength + length; + memcpy(dst + headerLength, src, length); + + if (!isServer) { + + // overwrites up to 3 bytes outside of the given buffer! + //WebSocketProtocol::unmaskInplace(dst + headerLength, dst + headerLength + length, mask); + + // this is not optimal + char *start = dst + headerLength; + char *stop = start + length; + int i = 0; + while (start != stop) { + (*start++) ^= mask[i++ % 4]; + } + } + return messageLength; + } + + void consume(char *src, unsigned int length, void *user) { + if (spillLength) { + src -= spillLength; + length += spillLength; + memcpy(src, spill, spillLength); + } + if (state == READ_HEAD) { + parseNext: + for (frameFormat frame; length >= SHORT_MESSAGE_HEADER; ) { + memcpy(&frame, src, sizeof(frameFormat)); + + // invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame + if ((rsv1(frame) && !setCompressed(user)) || rsv23(frame) || (getOpCode(frame) > 2 && getOpCode(frame) < 8) || + getOpCode(frame) > 10 || (getOpCode(frame) > 2 && (!isFin(frame) || payloadLength(frame) > 125))) { + forceClose(user); + return; + } + + if (payloadLength(frame) < 126) { + if (consumeMessage(payloadLength(frame), src, length, frame, user)) { + return; + } + } else if (payloadLength(frame) == 126) { + if (length < MEDIUM_MESSAGE_HEADER) { + break; + } else if(consumeMessage(ntohs(*(uint16_t *) &src[2]), src, length, frame, user)) { + return; + } + } else if (length < LONG_MESSAGE_HEADER) { + break; + } else if (consumeMessage(be64toh(*(uint64_t *) &src[2]), src, length, frame, user)) { + return; + } + } + if (length) { + memcpy(spill, src, length); + spillLength = length; + } + } else if (consumeContinuation(src, length, user)) { + goto parseNext; + } + } + + static const int CONSUME_POST_PADDING = 18; + static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1; + + // events to be implemented by application (can't be inline currently) + bool refusePayloadLength(void *user, int length); + bool setCompressed(void *user); + void forceClose(void *user); + bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, void *user); +}; + +} + +#endif // WEBSOCKETPROTOCOL_UWS_H diff --git a/src/uWS/uUV.cpp b/src/uWS/uUV.cpp new file mode 100644 index 000000000..1c72c6c70 --- /dev/null +++ b/src/uWS/uUV.cpp @@ -0,0 +1,375 @@ +#include "uUV.h" + +#ifndef USE_MICRO_UV +void uv_close(uv_async_t *handle, uv_close_cb cb) { + uv_close((uv_handle_t *) handle, cb); +} +void uv_close(uv_idle_t *handle, uv_close_cb cb) { + uv_close((uv_handle_t *) handle, cb); +} +void uv_close(uv_poll_t *handle, uv_close_cb cb) { + uv_close((uv_handle_t *) handle, cb); +} +void uv_close(uv_timer_t *handle, uv_close_cb cb) { + uv_close((uv_handle_t *) handle, cb); +} + +bool uv_is_closing(uv_async_t *handle) { + return uv_is_closing((uv_handle_t *) handle); +} +bool uv_is_closing(uv_idle_t *handle) { + return uv_is_closing((uv_handle_t *) handle); +} +bool uv_is_closing(uv_poll_t *handle) { + return uv_is_closing((uv_handle_t *) handle); +} +bool uv_is_closing(uv_timer_t *handle) { + return uv_is_closing((uv_handle_t *) handle); +} +#else + +#include + +//namespace uUV { + +uv_loop_t *loops[128]; +int loopHead = 0; + +#define CALLBACK_ARR_SIZE 128 +uv_async_cb async_callbacks[CALLBACK_ARR_SIZE]; +int asyncCbHead = 0; +uv_idle_cb idle_callbacks[CALLBACK_ARR_SIZE]; +int idleCbHead = 0; +uv_poll_cb poll_callbacks[CALLBACK_ARR_SIZE]; +int pollCbHead = 0; +uv_timer_cb timer_callbacks[CALLBACK_ARR_SIZE]; +int timerCbHead = 0; + +uv_loop_t *uv_handle_t::get_loop() const { + return loops[loopIndex]; +} + +inline uv_loop_t *uv_loop_helper() { + uv_loop_t *loop = new uv_loop_t; + loop->efd = epoll_create(1); + loop->index = loopHead++; + loop->numEvents = 0; + + loop->asyncWakeupFd = eventfd(0, EFD_SEMAPHORE | EFD_NONBLOCK); + struct epoll_event wakeupEvents; + wakeupEvents.events = EPOLLHUP | EPOLLERR | EPOLLIN | EPOLLET; + wakeupEvents.data.ptr = nullptr; + epoll_ctl(loop->efd, EPOLL_CTL_ADD, loop->asyncWakeupFd, &wakeupEvents); + + loops[loop->index] = loop; + return loop; +} + +inline void init() { + uv_loop_helper(); +} + +uv_loop_t *uv_default_loop() { + if (!loopHead) { + init(); + } + return loops[0]; +} + +uv_loop_t *uv_loop_new() { + if (!loopHead) + init(); + return uv_loop_helper(); +} + +void uv_loop_delete(uv_loop_t *loop) { + epoll_ctl(loop->efd, EPOLL_CTL_DEL, loop->asyncWakeupFd, nullptr); + close(loop->efd); + loops[loop->index] = nullptr; + delete loop; +} + +void uv_async_init(uv_loop_t *loop, uv_async_t *async, uv_async_cb cb) { + async->loopIndex = loop->index; + loop->numEvents++; + + async->cbIndex = asyncCbHead; + for (int i = 0; i < asyncCbHead; i++) { + if (async_callbacks[i] == cb) { + async->cbIndex = i; + break; + } + } + if (async->cbIndex == asyncCbHead) { + async_callbacks[asyncCbHead++] = cb; + } + + loop->asyncs.insert(async); +} + +void uv_async_send(uv_async_t *async) { + uv_loop_t *loop = async->get_loop(); + loop->async_mutex.lock(); + uint64_t val = 1; + int w = write(loop->asyncWakeupFd, &val, sizeof(val)); + async->run = true; + loop->async_mutex.unlock(); +} + +void uv_close(uv_async_t *handle, uv_close_cb cb) { + uv_loop_t *loop = handle->get_loop(); + + loop->asyncs.erase((uv_async_t *) handle); + + handle->flags |= UV_HANDLE_CLOSING; + loop->closing.push_back({(uv_handle_t *) handle, cb}); +} + +bool uv_is_closing(uv_async_t *handle) { + return handle->flags & (UV_HANDLE_CLOSING | UV_HANDLE_CLOSED); +} + +void uv_idle_init(uv_loop_t *loop, uv_idle_t *idle) { + idle->loopIndex = loop->index; + loop->numEvents++; +} + +void uv_idle_start(uv_idle_t *idle, uv_idle_cb cb) { + idle->cbIndex = idleCbHead; + for (int i = 0; i < idleCbHead; i++) { + if (idle_callbacks[i] == cb) { + idle->cbIndex = i; + break; + } + } + if (idle->cbIndex == idleCbHead) { + idle_callbacks[idleCbHead++] = cb; + } + + idle->get_loop()->idlers.insert(idle); +} + +void uv_idle_stop(uv_idle_t *idle) { + idle->get_loop()->idlers.erase(idle); +} + +void uv_close(uv_idle_t *handle, uv_close_cb cb) { + uv_loop_t *loop = handle->get_loop(); + handle->flags |= UV_HANDLE_CLOSING; + loop->closing.push_back({(uv_handle_t *) handle, cb}); +} + +bool uv_is_closing(uv_idle_t *handle) { + return handle->flags & (UV_HANDLE_CLOSING | UV_HANDLE_CLOSED); +} + +uv_poll_cb uv_poll_t::get_poll_cb() const { + return poll_callbacks[cbIndex]; +} + +int uv_poll_init_socket(uv_loop_t *loop, uv_poll_t *poll, uv_os_sock_t socket) { + int flags = fcntl(socket, F_GETFL, 0); + if (flags == -1) { + return -1; + } + flags |= O_NONBLOCK; + flags = fcntl (socket, F_SETFL, flags); + if (flags == -1) { + return -1; + } + + poll->loopIndex = loop->index; + poll->fd = socket; + poll->event.events = 0; + poll->event.data.ptr = poll; + loop->numEvents++; + return epoll_ctl(loop->efd, EPOLL_CTL_ADD, socket, &poll->event); +} + +int uv_poll_start(uv_poll_t *poll, int events, uv_poll_cb cb) { + poll->event.events = events; + poll->cbIndex = pollCbHead; + for (int i = 0; i < pollCbHead; i++) { + if (poll_callbacks[i] == cb) { + poll->cbIndex = i; + break; + } + } + if (poll->cbIndex == pollCbHead) { + poll_callbacks[pollCbHead++] = cb; + } + return epoll_ctl(poll->get_loop()->efd, EPOLL_CTL_MOD, poll->fd, &poll->event); +} + +int uv_poll_stop(uv_poll_t *poll) { + return epoll_ctl(poll->get_loop()->efd, EPOLL_CTL_DEL, poll->fd, &poll->event); +} + +void uv_close(uv_poll_t *handle, uv_close_cb cb) { + uv_loop_t *loop = handle->get_loop(); + + uv_poll_t *poll = (uv_poll_t *) handle; + poll->fd = -1; + + loop->closing.push_back({(uv_handle_t *) handle, cb}); +} + +bool uv_is_closing(uv_poll_t *handle) { + return handle->fd == -1; +} + +int uv_fileno(uv_poll_t *handle) { + return handle->fd; +} + +void uv_timer_init(uv_loop_t *loop, uv_timer_t *timer) { + timer->loopIndex = loop->index; + loop->numEvents++; + loop->timepoint = std::chrono::system_clock::now(); +} + +void uv_timer_enqueue(uv_timer_t *timer, int timeout) { + timer->timepoint = timer->get_loop()->timepoint + std::chrono::milliseconds(timeout); + // sort timers from farthest to soonest so we can pop from back in O(1) + uv_loop_t *loop = timer->get_loop(); + if (loop->timers.size() && timeout) { + loop->timers.insert( + std::upper_bound(loop->timers.begin(), loop->timers.end(), timer, [](uv_timer_t* a, uv_timer_t* b) { + return a->timepoint > b->timepoint; + }), + timer + ); + } + else + loop->timers.push_back(timer); +} +void uv_timer_start(uv_timer_t *timer, uv_timer_cb cb, int timeout, int repeat) { + timer->cbIndex = timerCbHead; + for (int i = 0; i < timerCbHead; i++) { + if (timer_callbacks[i] == cb) { + timer->cbIndex = i; + break; + } + } + if (timer->cbIndex == timerCbHead) { + timer_callbacks[timerCbHead++] = cb; + } + + timer->repeat = repeat; + timer->flags = UV_HANDLE_RUNNING; + uv_timer_enqueue(timer, timeout); +} + +void uv_timer_stop(uv_timer_t *timer) { + timer->flags &= ~UV_HANDLE_RUNNING; + uv_loop_t *loop = timer->get_loop(); + for (int i = 0; i < loop->timers.size(); ++i) + if (loop->timers[i] == timer) + { + loop->timers.erase(loop->timers.begin() + i); + break; + } +} + +void uv_close(uv_timer_t *handle, uv_close_cb cb) { + uv_loop_t *loop = handle->get_loop(); + handle->flags |= UV_HANDLE_CLOSING; + loop->closing.push_back({(uv_handle_t *) handle, cb}); +} + +bool uv_is_closing(uv_timer_t *handle) { + return handle->flags & (UV_HANDLE_CLOSING | UV_HANDLE_CLOSED); +} + +void uv_run(uv_loop_t *loop, int mode) { + loop->timepoint = std::chrono::system_clock::now(); + signal(SIGPIPE, SIG_IGN); + int iter = 0; + while (loop->numEvents) { + ++iter; + // Close any events that are ready to close + if (loop->closing.size()) { + // Make a copy so that its ok to call uv_close in the callbacks + std::vector> closingCopy = loop->closing; + loop->closing.clear(); + + for (std::pair c : closingCopy) { + loop->numEvents--; + c.first->flags &= ~UV_HANDLE_CLOSING; + c.first->flags |= UV_HANDLE_CLOSED; + c.second(c.first); + } + } + + // Wait for events to be ready + loop->timepoint = std::chrono::system_clock::now(); + int delay = -1; + if (loop->idlers.size()) { + delay = 0; + } else if (loop->timers.size()) { + delay = std::max(std::chrono::duration_cast(loop->timers.back()->timepoint - loop->timepoint).count(), 0); + } + epoll_event readyEvents[1024]; + int numFdReady = epoll_wait(loop->efd, readyEvents, 1024, delay); + + // Handle polling events + for (int i = 0; i < numFdReady; i++) { + uv_poll_t *poll = (uv_poll_t *) readyEvents[i].data.ptr; + if (poll) { + int status = -bool(readyEvents[i].events & EPOLLERR); + poll_callbacks[poll->cbIndex](poll, status, readyEvents[i].events); + } else { // async wakeup event has nullptr + loop->async_mutex.lock(); + uint64_t val; + int r = read(loop->asyncWakeupFd, &val, sizeof(val)); + loop->async_mutex.unlock(); + } + } + + // Handle async events + if (loop->asyncs.size()) { + std::vector readyAsyncs; + // Find ready asyncs first so we can safely modify the set inside callbacks + loop->async_mutex.lock(); + for (uv_async_t *async : loop->asyncs) + if (async->run) + { + async->run = false; + readyAsyncs.push_back(async); + } + loop->async_mutex.unlock(); + for (uv_async_t *async : readyAsyncs) + async_callbacks[async->cbIndex](async); + } + + // Handle idle events + if (loop->idlers.size()) { + std::unordered_set readyIdlers = loop->idlers; + for (uv_idle_t *idle : readyIdlers) + idle_callbacks[idle->cbIndex](idle); + } + + // Handle timer events + if (loop->timers.size()) { + loop->timepoint = std::chrono::system_clock::now(); + // Copy ready timers to separate vector so callbacks can safely modify the original + std::vector readyTimers; + while (loop->timers.size() && std::chrono::duration_cast(loop->timers.back()->timepoint - loop->timepoint).count() <= 0) { + readyTimers.push_back(loop->timers.back()); + loop->timers.pop_back(); + } + for (uv_timer_t* timer : readyTimers) { + if (timer->flags & UV_HANDLE_RUNNING) { + timer_callbacks[timer->cbIndex](timer); + // Have to check for running again in case timer was stopped in callback + if (timer->repeat && timer->flags & UV_HANDLE_RUNNING) { + uv_timer_enqueue(timer, timer->repeat); + } + } + } + } + } +} + +//} +#endif diff --git a/src/uWS/uUV.h b/src/uWS/uUV.h new file mode 100644 index 000000000..ea3f749ac --- /dev/null +++ b/src/uWS/uUV.h @@ -0,0 +1,162 @@ +#ifndef UUV_H +#define UUV_H + +#ifndef USE_MICRO_UV +#include +void uv_close(uv_async_t *handle, uv_close_cb cb); +void uv_close(uv_idle_t *handle, uv_close_cb cb); +void uv_close(uv_poll_t *handle, uv_close_cb cb); +void uv_close(uv_timer_t *handle, uv_close_cb cb); + +bool uv_is_closing(uv_async_t *handle); +bool uv_is_closing(uv_idle_t *handle); +bool uv_is_closing(uv_poll_t *handle); +bool uv_is_closing(uv_timer_t *handle); +#else + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define UV_VERSION_MINOR 3 + +#include +#include +#include +#include + +//namespace uUV { + +struct uv_handle_t; +struct uv_loop_t; + +struct uv_async_t; +struct uv_idle_t; +struct uv_poll_t; +struct uv_timer_t; + +// Error codes +const int UV_EINVAL = -1; +const int UV_EBADF = -2; + +// Handle flags +const unsigned char UV_HANDLE_CLOSING = 0x01 << 0; +const unsigned char UV_HANDLE_CLOSED = 0x01 << 1; +const unsigned char UV_HANDLE_RUNNING = 0x01 << 2; + +const int UV_WRITABLE = EPOLLOUT; +const int UV_READABLE = EPOLLIN | EPOLLHUP; +const int UV_DISCONNECT = 4; // Not sure which epoll events correspond to disconnect. This value is taken from libuv source code instead. +const int UV_RUN_DEFAULT = 0; +typedef int uv_os_sock_t; +typedef void (*uv_close_cb)(uv_handle_t *handle); +typedef void (*uv_async_cb)(uv_async_t *handle); +typedef void (*uv_idle_cb)(uv_idle_t *handle); +typedef void (*uv_poll_cb)(uv_poll_t *poll, int status, int events); +typedef void (*uv_timer_cb)(uv_timer_t *handle); + +/* + * All struct members are specifically ordered to optimize packing! Be careful + * when making changes. + * + * Member size reference: + * std::mutex .................................... 40 bytes + * std::chrono::system_clock::time_point ......... 8 bytes + * std::vector ................................... 24 bytes + * std::unordered_set ............................ 56 bytes + * epoll_event ................................... 12 bytes + */ + +// 16 bytes +struct uv_handle_t { + void *data; + unsigned char flags = 0; + unsigned char loopIndex; + + uv_loop_t *get_loop() const; +}; + +// 224 bytes +struct uv_loop_t { + std::unordered_set asyncs; + std::unordered_set idlers; + std::vector timers; + std::vector> closing; + std::mutex async_mutex; + std::chrono::system_clock::time_point timepoint; + int efd; + int index; + int numEvents; + int asyncWakeupFd; +}; + +uv_loop_t *uv_default_loop(); +uv_loop_t *uv_loop_new(); +void uv_loop_delete(uv_loop_t *loop); + + +// 16 bytes +struct uv_async_t : uv_handle_t { + unsigned char cbIndex; + bool run = false; +}; + +void uv_async_init(uv_loop_t *loop, uv_async_t *async, uv_async_cb cb); +void uv_async_send(uv_async_t *async); +void uv_close(uv_async_t *handle, uv_close_cb cb); +bool uv_is_closing(uv_async_t *handle); + +// 16 bytes +struct uv_idle_t : uv_handle_t { + unsigned char cbIndex; +}; + +void uv_idle_init(uv_loop_t *loop, uv_idle_t *idle); +void uv_idle_start(uv_idle_t *idle, uv_idle_cb cb); +void uv_idle_stop(uv_idle_t *idle); +void uv_close(uv_idle_t *handle, uv_close_cb cb); +bool uv_is_closing(uv_idle_t *handle); + +// 32 bytes +struct uv_poll_t : uv_handle_t { + unsigned char cbIndex; + int fd; + epoll_event event; + + uv_poll_cb get_poll_cb() const; +}; + +int uv_poll_init_socket(uv_loop_t *loop, uv_poll_t *poll, uv_os_sock_t socket); +int uv_poll_start(uv_poll_t *poll, int events, uv_poll_cb cb); +int uv_poll_stop(uv_poll_t *poll); +void uv_close(uv_poll_t *handle, uv_close_cb cb); +bool uv_is_closing(uv_poll_t *handle); +int uv_fileno(uv_poll_t *handle); + +// 24 bytes +struct uv_timer_t : uv_handle_t { + unsigned char cbIndex; + int repeat; + std::chrono::system_clock::time_point timepoint; +}; + +void uv_timer_init(uv_loop_t *loop, uv_timer_t *timer); +void uv_timer_start(uv_timer_t *timer, uv_timer_cb cb, int timeout, int repeat); +void uv_timer_stop(uv_timer_t *timer); +void uv_close(uv_timer_t *handle, uv_close_cb cb); +bool uv_is_closing(uv_timer_t *handle); + +void uv_run(uv_loop_t *loop, int mode); + +//} // namespace uUV + +#endif +#endif // UUV_H diff --git a/src/uWS/uWS.h b/src/uWS/uWS.h new file mode 100644 index 000000000..40a0e403d --- /dev/null +++ b/src/uWS/uWS.h @@ -0,0 +1,6 @@ +#ifndef UWS_UWS_H +#define UWS_UWS_H + +#include "Hub.h" + +#endif // UWS_UWS_H