refactor select code + add protection against large fds (cf Android 9)

This commit is contained in:
Benjamin Sergeant 2019-06-25 15:41:39 -07:00
parent 0423ed01a6
commit e8a20c7e8a
8 changed files with 76 additions and 59 deletions

View File

@ -221,7 +221,7 @@ setOnConnectionCallback(
{ {
// Build a string for the response // Build a string for the response
std::stringstream ss; std::stringstream ss;
ss << request->method ss << request->method
<< " " << " "
<< request->uri; << request->uri;

View File

@ -131,7 +131,7 @@ namespace ix
return false; return false;
} }
return response->payload.empty() return response->payload.empty()
? true ? true
: socket->writeBytes(response->payload, nullptr); : socket->writeBytes(response->payload, nullptr);
} }

View File

@ -131,7 +131,7 @@ namespace ix
// Log request // Log request
std::stringstream ss; std::stringstream ss;
ss << request->method ss << request->method
<< " " << " "
<< request->uri << request->uri
<< " " << " "

View File

@ -47,23 +47,44 @@ namespace ix
PollResultType Socket::poll(bool readyToRead, PollResultType Socket::poll(bool readyToRead,
int timeoutMs, int timeoutMs,
int sockfd, int sockfd,
int interruptFd) std::shared_ptr<SelectInterrupt> selectInterrupt)
{ {
fd_set rfds; fd_set rfds;
fd_set wfds; fd_set wfds;
fd_set efds;
FD_ZERO(&rfds); FD_ZERO(&rfds);
FD_ZERO(&wfds); FD_ZERO(&wfds);
FD_ZERO(&efds);
// FD_SET cannot handle fds larger than FD_SETSIZE.
if (sockfd >= FD_SETSIZE)
{
return PollResultType::Error;
}
fd_set* fds = (readyToRead) ? &rfds : & wfds; fd_set* fds = (readyToRead) ? &rfds : & wfds;
if (sockfd != -1) if (sockfd != -1)
{ {
FD_SET(sockfd, fds); FD_SET(sockfd, fds);
FD_SET(sockfd, &efds);
} }
// File descriptor used to interrupt select when needed // File descriptor used to interrupt select when needed
if (interruptFd != -1) int interruptFd = -1;
if (selectInterrupt)
{ {
FD_SET(interruptFd, fds); interruptFd = selectInterrupt->getFd();
// FD_SET cannot handle fds larger than FD_SETSIZE.
if (interruptFd >= FD_SETSIZE)
{
return PollResultType::Error;
}
if (interruptFd != -1)
{
FD_SET(interruptFd, fds);
}
} }
struct timeval timeout; struct timeval timeout;
@ -73,7 +94,7 @@ namespace ix
// Compute the highest fd. // Compute the highest fd.
int nfds = (std::max)(sockfd, interruptFd); int nfds = (std::max)(sockfd, interruptFd);
int ret = ::select(nfds + 1, &rfds, &wfds, nullptr, int ret = ::select(nfds + 1, &rfds, &wfds, &efds,
(timeoutMs < 0) ? nullptr : &timeout); (timeoutMs < 0) ? nullptr : &timeout);
PollResultType pollResult = PollResultType::ReadyForRead; PollResultType pollResult = PollResultType::ReadyForRead;
@ -87,7 +108,7 @@ namespace ix
} }
else if (interruptFd != -1 && FD_ISSET(interruptFd, &rfds)) else if (interruptFd != -1 && FD_ISSET(interruptFd, &rfds))
{ {
uint64_t value = _selectInterrupt->read(); uint64_t value = selectInterrupt->read();
if (value == kSendRequest) if (value == kSendRequest)
{ {
@ -105,6 +126,25 @@ namespace ix
else if (sockfd != -1 && !readyToRead && FD_ISSET(sockfd, &wfds)) else if (sockfd != -1 && !readyToRead && FD_ISSET(sockfd, &wfds))
{ {
pollResult = PollResultType::ReadyForWrite; pollResult = PollResultType::ReadyForWrite;
#ifdef _WIN32
// On connect error, in async mode, windows will write to the exceptions fds
if (FD_ISSET(fd, &efds))
{
pollResult = PollResultType::Error;
}
#else
int optval = -1;
socklen_t optlen = sizeof(optval);
// getsockopt() puts the errno value for connect into optval so 0
// means no-error.
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == -1 ||
optval != 0)
{
pollResult = PollResultType::Error;
}
#endif
} }
return pollResult; return pollResult;
@ -118,7 +158,7 @@ namespace ix
} }
bool readyToRead = true; bool readyToRead = true;
return poll(readyToRead, timeoutMs, _sockfd, _selectInterrupt->getFd()); return poll(readyToRead, timeoutMs, _sockfd, _selectInterrupt);
} }
PollResultType Socket::isReadyToWrite(int timeoutMs) PollResultType Socket::isReadyToWrite(int timeoutMs)
@ -129,7 +169,7 @@ namespace ix
} }
bool readyToRead = false; bool readyToRead = false;
return poll(readyToRead, timeoutMs, _sockfd, _selectInterrupt->getFd()); return poll(readyToRead, timeoutMs, _sockfd, _selectInterrupt);
} }
// Wake up from poll/select by writing to the pipe which is watched by select // Wake up from poll/select by writing to the pipe which is watched by select
@ -247,7 +287,7 @@ namespace ix
else else
{ {
buffer += ret; buffer += ret;
len -= ret; len -= ret;
continue; continue;
} }
} }

View File

@ -88,6 +88,12 @@ namespace ix
static bool isWaitNeeded(); static bool isWaitNeeded();
static void closeSocket(int fd); static void closeSocket(int fd);
static PollResultType poll(bool readyToRead,
int timeoutMs,
int sockfd,
std::shared_ptr<SelectInterrupt> selectInterrupt = nullptr);
// Used as special codes for pipe communication // Used as special codes for pipe communication
static const uint64_t kSendRequest; static const uint64_t kSendRequest;
static const uint64_t kCloseRequest; static const uint64_t kCloseRequest;
@ -97,11 +103,6 @@ namespace ix
std::mutex _socketMutex; std::mutex _socketMutex;
private: private:
PollResultType poll(bool readyToRead,
int timeoutMs,
int sockfd,
int interruptFd);
static const int kDefaultPollTimeout; static const int kDefaultPollTimeout;
static const int kDefaultPollNoTimeout; static const int kDefaultPollNoTimeout;

View File

@ -63,55 +63,31 @@ namespace ix
return -1; return -1;
} }
// On Linux the timeout needs to be re-initialized everytime int timeoutMs = 10;
// http://man7.org/linux/man-pages/man2/select.2.html bool readyToRead = false;
struct timeval timeout; PollResultType pollResult = Socket::poll(readyToRead, timeoutMs, fd);
timeout.tv_sec = 0;
timeout.tv_usec = 10 * 1000; // 10ms timeout
fd_set wfds; if (pollResult == PollResultType::Timeout)
fd_set efds; {
continue;
FD_ZERO(&wfds); }
FD_SET(fd, &wfds); else if (pollResult == PollResultType::Error)
FD_ZERO(&efds);
FD_SET(fd, &efds);
// Use select to check the status of the new connection
res = select(fd + 1, nullptr, &wfds, &efds, &timeout);
if (res < 0 && (Socket::getErrno() == EBADF || Socket::getErrno() == EINVAL))
{ {
Socket::closeSocket(fd); Socket::closeSocket(fd);
errMsg = std::string("Connect error, select error: ") + strerror(Socket::getErrno()); errMsg = std::string("Connect error: ") +
strerror(Socket::getErrno());
return -1; return -1;
} }
else if (pollResult == PollResultType::ReadyForWrite)
// Nothing was written to the socket, wait again.
if (!FD_ISSET(fd, &wfds)) continue;
// Something was written to the socket. Check for errors.
int optval = -1;
socklen_t optlen = sizeof(optval);
#ifdef _WIN32
// On connect error, in async mode, windows will write to the exceptions fds
if (FD_ISSET(fd, &efds))
#else
// getsockopt() puts the errno value for connect into optval so 0
// means no-error.
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == -1 ||
optval != 0)
#endif
{ {
Socket::closeSocket(fd); return fd;
errMsg = strerror(optval);
return -1;
} }
else else
{ {
// Success ! Socket::closeSocket(fd);
return fd; errMsg = std::string("Connect error: ") +
strerror(Socket::getErrno());
return -1;
} }
} }

View File

@ -542,7 +542,7 @@ namespace ix
) { ) {
unmaskReceiveBuffer(ws); unmaskReceiveBuffer(ws);
MessageKind messageKind = MessageKind messageKind =
(ws.opcode == wsheader_type::TEXT_FRAME) (ws.opcode == wsheader_type::TEXT_FRAME)
? MessageKind::MSG_TEXT ? MessageKind::MSG_TEXT
: MessageKind::MSG_BINARY; : MessageKind::MSG_BINARY;

View File

@ -33,7 +33,7 @@ namespace
uint16_t getCloseCode(); uint16_t getCloseCode();
const std::string& getCloseReason(); const std::string& getCloseReason();
bool getCloseRemote(); bool getCloseRemote();
bool hasConnectionError() const; bool hasConnectionError() const;
private: private:
@ -56,7 +56,7 @@ namespace
{ {
; ;
} }
bool WebSocketClient::hasConnectionError() const bool WebSocketClient::hasConnectionError() const
{ {
return _connectionError; return _connectionError;