/*
 *  IXSocketServer.cpp
 *  Author: Benjamin Sergeant
 *  Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
 */

#include "IXSocketServer.h"

#include "IXNetSystem.h"
#include "IXSelectInterrupt.h"
#include "IXSelectInterruptFactory.h"
#include "IXSetThreadName.h"
#include "IXSocket.h"
#include "IXSocketConnect.h"
#include "IXSocketFactory.h"
#include <assert.h>
#include <sstream>
#include <stdio.h>
#include <string.h>

namespace ix
{
    const int SocketServer::kDefaultPort(8080);
    const std::string SocketServer::kDefaultHost("");
    const int SocketServer::kDefaultTcpBacklog(5);
    const size_t SocketServer::kDefaultMaxConnections(128);
    const int SocketServer::kDefaultAddressFamily(AF_INET);

    //
    // Store the listening parameters. No socket is created here; that happens
    // in listen(). The server starts with an invalid descriptor (-1), both stop
    // flags cleared, and the default connection state factory.
    //
    SocketServer::SocketServer(
        int port, const std::string& host, int backlog, size_t maxConnections, int addressFamily)
        : _port(port)
        , _host(host)
        , _backlog(backlog)
        , _maxConnections(maxConnections)
        , _addressFamily(addressFamily)
        , _serverFd(-1)
        , _stop(false)
        , _stopGc(false)
        , _connectionStateFactory(&ConnectionState::createConnectionState)
        , _acceptSelectInterrupt(createSelectInterrupt())
    {
    }

    // Stop the worker threads and release the listening socket.
    SocketServer::~SocketServer()
    {
        stop();
    }

    // Write one line to stderr. _logMutex serializes writers.
    void SocketServer::logError(const std::string& str)
    {
        std::lock_guard<std::mutex> lock(_logMutex);
        fprintf(stderr, "%s\n", str.c_str());
    }

    // Write one line to stdout. _logMutex serializes writers.
    void SocketServer::logInfo(const std::string& str)
    {
        std::lock_guard<std::mutex> lock(_logMutex);
        fprintf(stdout, "%s\n", str.c_str());
    }

    //
    // Create the server socket, bind it to _host:_port and start listening
    // with a backlog of _backlog.
    //
    // Returns (true, "") on success, or (false, <error description>) on
    // failure. On every failure path the socket (if any) is closed and
    // _serverFd is reset to -1, so that a later stop() does not close a stale
    // descriptor number a second time.
    //
    std::pair<bool, std::string> SocketServer::listen()
    {
        std::string acceptSelectInterruptInitErrorMsg;
        if (!_acceptSelectInterrupt->init(acceptSelectInterruptInitErrorMsg))
        {
            std::stringstream ss;
            ss << "SocketServer::listen() error in SelectInterrupt::init: "
               << acceptSelectInterruptInitErrorMsg;

            return std::make_pair(false, ss.str());
        }

        if (_addressFamily != AF_INET && _addressFamily != AF_INET6)
        {
            std::string errMsg("SocketServer::listen() AF_INET and AF_INET6 are currently "
                               "the only supported address families");
            return std::make_pair(false, errMsg);
        }

        // Get a socket for accepting connections.
        if ((_serverFd = socket(_addressFamily, SOCK_STREAM, 0)) < 0)
        {
            std::stringstream ss;
            ss << "SocketServer::listen() error creating socket): "
               << strerror(Socket::getErrno());

            _serverFd = -1;
            return std::make_pair(false, ss.str());
        }

        // Shared failure path for every step that follows socket creation:
        // format the error, close the socket and invalidate the descriptor.
        auto failAndClose = [this](const std::string& what) {
            // Read the error code before closing, closing could overwrite it.
            const int err = Socket::getErrno();

            std::stringstream ss;
            ss << "SocketServer::listen() error calling " << what << " "
               << "at address " << _host << ":" << _port << " : " << strerror(err);

            Socket::closeSocket(_serverFd);
            _serverFd = -1;
            return std::make_pair(false, ss.str());
        };

        // Make that socket reusable. (allow restarting this server at will)
        int enable = 1;
        if (setsockopt(_serverFd,
                       SOL_SOCKET,
                       SO_REUSEADDR,
                       reinterpret_cast<const char*>(&enable),
                       sizeof(enable)) < 0)
        {
            return failAndClose("setsockopt(SO_REUSEADDR)");
        }

        if (_addressFamily == AF_INET)
        {
            // Zero the whole structure: the padding must not contain garbage.
            struct sockaddr_in server;
            memset(&server, 0, sizeof(server));
            server.sin_family = _addressFamily;
            server.sin_port = htons(_port);

            if (inet_pton(_addressFamily, _host.c_str(), &server.sin_addr.s_addr) <= 0)
            {
                return failAndClose("inet_pton");
            }

            // Bind the socket to the server address.
            if (bind(_serverFd, reinterpret_cast<struct sockaddr*>(&server), sizeof(server)) < 0)
            {
                return failAndClose("bind");
            }
        }
        else // AF_INET6
        {
            // Zero the whole structure: sin6_flowinfo and sin6_scope_id are
            // not set below and must not be left uninitialized.
            struct sockaddr_in6 server;
            memset(&server, 0, sizeof(server));
            server.sin6_family = _addressFamily;
            server.sin6_port = htons(_port);

            if (inet_pton(_addressFamily, _host.c_str(), &server.sin6_addr) <= 0)
            {
                return failAndClose("inet_pton");
            }

            // Bind the socket to the server address.
            if (bind(_serverFd, reinterpret_cast<struct sockaddr*>(&server), sizeof(server)) < 0)
            {
                return failAndClose("bind");
            }
        }

        //
        // Listen for connections. Specify the tcp backlog.
        //
        if (::listen(_serverFd, _backlog) < 0)
        {
            return failAndClose("listen");
        }

        return std::make_pair(true, "");
    }

    //
    // Launch the accept thread (run) and the garbage collection thread (runGC),
    // unless they are already running.
    //
    void SocketServer::start()
    {
        _stop = false;

        if (!_thread.joinable())
        {
            _thread = std::thread(&SocketServer::run, this);
        }

        if (!_gcThread.joinable())
        {
            _gcThread = std::thread(&SocketServer::runGC, this);
        }
    }

    //
    // Block the caller until _conditionVariable is notified, which stop() does.
    //
    // NOTE(review): the wait has no predicate, so it can also return on a
    // spurious wakeup, and a notification sent before the wait started is not
    // observed. Fixing this needs a dedicated state flag in the class.
    //
    void SocketServer::wait()
    {
        std::unique_lock<std::mutex> lock(_conditionVariableMutex);
        _conditionVariable.wait(lock);
    }

    // Ask the accept loop to exit. The loop is not woken up and not joined here.
    void SocketServer::stopAcceptingConnections()
    {
        _stop = true;
    }

    //
    // Stop the accept thread, then the GC thread (which only exits once all
    // connection threads have been joined), wake up wait() and close the
    // listening socket. Calling stop() more than once is safe.
    //
    void SocketServer::stop()
    {
        // Stop accepting connections, and close the 'accept' thread
        if (_thread.joinable())
        {
            _stop = true;

            // Wake up select
            if (!_acceptSelectInterrupt->notify(SelectInterrupt::kCloseRequest))
            {
                logError("SocketServer::stop: Cannot wake up from select");
            }

            _thread.join();
            _stop = false;
        }

        // Join all threads and make sure that all connections are terminated
        if (_gcThread.joinable())
        {
            {
                // Set the flag while holding the mutex runGC() holds when it
                // checks the flag and goes to sleep; otherwise the notification
                // below could be sent between the check and the wait, and be
                // lost, leaving join() blocked forever.
                std::lock_guard<std::mutex> lock(_conditionVariableMutexGC);
                _stopGc = true;
            }
            _conditionVariableGC.notify_one();
            _gcThread.join();
            _stopGc = false;
        }

        _conditionVariable.notify_one();

        // Close the listening socket only once: a second close could hit an
        // unrelated descriptor that was given the same number in the meantime.
        if (_serverFd != -1)
        {
            Socket::closeSocket(_serverFd);
            _serverFd = -1;
        }
    }

    // Replace the factory used by run() to create the state of each connection.
    void SocketServer::setConnectionStateFactory(
        const ConnectionStateFactory& connectionStateFactory)
    {
        _connectionStateFactory = connectionStateFactory;
    }

    //
    // join the threads for connections that have been closed
    //
    // When a connection is closed by a client, the connection state terminated
    // field becomes true, and we can use that to know that we can join that thread
    // and remove it from our _connectionsThreads data structure (a list).
    //
    void SocketServer::closeTerminatedThreads()
    {
        std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);

        auto it = _connectionsThreads.begin();
        while (it != _connectionsThreads.end())
        {
            auto& connectionState = it->first;
            auto& thread = it->second;

            // Still alive, leave it alone.
            if (!connectionState->isTerminated())
            {
                ++it;
                continue;
            }

            if (thread.joinable()) thread.join();
            it = _connectionsThreads.erase(it);
        }
    }

    //
    // Accept loop, executed in its own thread (see start()).
    //
    // For each incoming connection: enforce the _maxConnections limit, extract
    // the remote ip and port, create the connection state and the socket, then
    // spawn a thread running handleConnection. The loop exits when _stop is set.
    //
    void SocketServer::run()
    {
        // Set the socket to non blocking mode, so that accept calls are not blocking
        SocketConnect::configure(_serverFd);

        setThreadName("SocketServer::accept");

        for (;;)
        {
            if (_stop) return;

            // Use poll to check whether a new connection is in progress
            int timeoutMs = -1;
#ifdef _WIN32
            // select cannot be interrupted on Windows so we need to pass a small timeout
            timeoutMs = 10;
#endif

            bool readyToRead = true;
            PollResultType pollResult =
                Socket::poll(readyToRead, timeoutMs, _serverFd, _acceptSelectInterrupt);

            if (pollResult == PollResultType::Error)
            {
                std::stringstream ss;
                ss << "SocketServer::run() error in select: " << strerror(Socket::getErrno());
                logError(ss.str());
                continue;
            }

            if (pollResult != PollResultType::ReadyForRead)
            {
                continue;
            }

            // Accept a connection.
            // sockaddr_storage is large enough for both an ipv4 and an ipv6
            // peer address, a sockaddr_in would truncate ipv6 addresses.
            struct sockaddr_storage client; // client address information
            memset(&client, 0, sizeof(client));
            socklen_t addressLen = sizeof(client);

            // socket connected to client
            int clientFd =
                accept(_serverFd, reinterpret_cast<struct sockaddr*>(&client), &addressLen);

            if (clientFd < 0)
            {
                if (!Socket::isWaitNeeded())
                {
                    // FIXME: that error should be propagated
                    int err = Socket::getErrno();
                    std::stringstream ss;
                    ss << "SocketServer::run() error accepting connection: " << err << ", "
                       << strerror(err);
                    logError(ss.str());
                }
                continue;
            }

            if (getConnectedClientsCount() >= _maxConnections)
            {
                std::stringstream ss;
                ss << "SocketServer::run() reached max connections = " << _maxConnections << ". "
                   << "Not accepting connection";
                logError(ss.str());

                Socket::closeSocket(clientFd);

                continue;
            }

            // Retrieve connection info, the ip address of the remote peer/client)
            std::string remoteIp;
            int remotePort = 0;

            if (_addressFamily == AF_INET)
            {
                const struct sockaddr_in* client4 =
                    reinterpret_cast<const struct sockaddr_in*>(&client);

                char remoteIp4[INET_ADDRSTRLEN];
                if (inet_ntop(AF_INET, &client4->sin_addr, remoteIp4, INET_ADDRSTRLEN) == nullptr)
                {
                    int err = Socket::getErrno();
                    std::stringstream ss;
                    ss << "SocketServer::run() error calling inet_ntop (ipv4): " << err << ", "
                       << strerror(err);
                    logError(ss.str());

                    Socket::closeSocket(clientFd);

                    continue;
                }

                // The port is stored in network byte order.
                remotePort = ntohs(client4->sin_port);
                remoteIp = remoteIp4;
            }
            else // AF_INET6
            {
                const struct sockaddr_in6* client6 =
                    reinterpret_cast<const struct sockaddr_in6*>(&client);

                char remoteIp6[INET6_ADDRSTRLEN];
                if (inet_ntop(AF_INET6, &client6->sin6_addr, remoteIp6, INET6_ADDRSTRLEN) ==
                    nullptr)
                {
                    int err = Socket::getErrno();
                    std::stringstream ss;
                    ss << "SocketServer::run() error calling inet_ntop (ipv6): " << err << ", "
                       << strerror(err);
                    logError(ss.str());

                    Socket::closeSocket(clientFd);

                    continue;
                }

                // The port is stored in network byte order.
                remotePort = ntohs(client6->sin6_port);
                remoteIp = remoteIp6;
            }

            std::shared_ptr<ConnectionState> connectionState;
            if (_connectionStateFactory)
            {
                connectionState = _connectionStateFactory();
            }

            // An empty factory (or one returning nullptr) must not crash the
            // accept thread: drop this connection and keep serving.
            if (!connectionState)
            {
                logError("SocketServer::run() cannot create connection state");
                Socket::closeSocket(clientFd);
                continue;
            }

            connectionState->setOnSetTerminatedCallback([this] { onSetTerminatedCallback(); });
            connectionState->setRemoteIp(remoteIp);
            connectionState->setRemotePort(remotePort);

            if (_stop)
            {
                // Do not leak the descriptor that was just accepted.
                Socket::closeSocket(clientFd);
                return;
            }

            // create socket
            std::string errorMsg;
            bool tls = _socketTLSOptions.tls;
            auto socket = createSocket(tls, clientFd, errorMsg, _socketTLSOptions);

            if (socket == nullptr)
            {
                logError("SocketServer::run() cannot create socket: " + errorMsg);
                Socket::closeSocket(clientFd);
                continue;
            }

            // Set the socket to non blocking mode + other tweaks
            SocketConnect::configure(clientFd);

            if (!socket->accept(errorMsg))
            {
                logError("SocketServer::run() tls accept failed: " + errorMsg);
                Socket::closeSocket(clientFd);
                continue;
            }

            // Launch the handleConnection work asynchronously in its own thread.
            std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
            _connectionsThreads.push_back(std::make_pair(
                connectionState,
                std::thread(
                    &SocketServer::handleConnection, this, std::move(socket), connectionState)));
        }
    }

    // Number of connection threads currently tracked (terminated or not).
    size_t SocketServer::getConnectionsThreadsCount()
    {
        std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
        return _connectionsThreads.size();
    }

    //
    // Garbage collection loop, executed in its own thread (see start()).
    //
    // Joins the threads of terminated connections each time it is woken up,
    // and exits once a stop was requested and no connection thread remains.
    //
    void SocketServer::runGC()
    {
        setThreadName("SocketServer::GC");

        for (;;)
        {
            // Garbage collection to shutdown/join threads for closed connections.
            closeTerminatedThreads();

            // We quit this thread if all connections are closed and we received
            // a stop request by setting _stopGc to true.
            if (_stopGc && getConnectionsThreadsCount() == 0)
            {
                break;
            }

            // Unless we are stopping the server, wait for a connection
            // to be terminated to run the threads GC, instead of busy waiting
            // with a sleep
            //
            // The flag is tested with the mutex held, and stop() sets it with
            // the same mutex held, so a stop request cannot slip in between
            // the test and the wait.
            std::unique_lock<std::mutex> lock(_conditionVariableMutexGC);
            if (!_stopGc)
            {
                _conditionVariableGC.wait(lock);
            }
        }
    }

    // Set the TLS options used by run() when creating the socket of a new connection.
    void SocketServer::setTLSOptions(const SocketTLSOptions& socketTLSOptions)
    {
        _socketTLSOptions = socketTLSOptions;
    }

    void SocketServer::onSetTerminatedCallback()
    {
        // a connection got terminated, we can run the connection thread GC,
        // so wake up the thread responsible for that
        _conditionVariableGC.notify_one();
    }
} // namespace ix