openssl cleanup
This commit is contained in:
		@@ -4,6 +4,10 @@
 | 
			
		||||
 *  Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// http://itamarst.org/writings/win32sockets.html
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
#include "IXSocketConnect.h"
 | 
			
		||||
#include "IXDNSLookup.h"
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -18,67 +18,15 @@
 | 
			
		||||
#include <errno.h>
 | 
			
		||||
#define socketerrno errno
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::mutex initMutex;
 | 
			
		||||
bool openSSLInitialized = false;
 | 
			
		||||
bool openSSLInitializationSuccessful = false;
 | 
			
		||||
 | 
			
		||||
bool openSSLInitialize(std::string& errMsg)
 | 
			
		||||
{
 | 
			
		||||
    std::lock_guard<std::mutex> lock(initMutex);
 | 
			
		||||
 | 
			
		||||
    if (openSSLInitialized)
 | 
			
		||||
    {
 | 
			
		||||
        return openSSLInitializationSuccessful;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
 | 
			
		||||
    if (!OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, nullptr))
 | 
			
		||||
    {
 | 
			
		||||
        errMsg = "OPENSSL_init_ssl failure";
 | 
			
		||||
 | 
			
		||||
        openSSLInitializationSuccessful = false;
 | 
			
		||||
        openSSLInitialized = true;
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    (void) OPENSSL_config(nullptr);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
    (void) OpenSSL_add_ssl_algorithms();
 | 
			
		||||
    (void) SSL_load_error_strings();
 | 
			
		||||
 | 
			
		||||
    openSSLInitializationSuccessful = true;
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int openssl_verify_callback(int preverify, X509_STORE_CTX *x509_ctx)
 | 
			
		||||
{
 | 
			
		||||
    return preverify;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* create new SSL connection state object */
 | 
			
		||||
SSL *openssl_create_connection(SSL_CTX *ctx, int socket)
 | 
			
		||||
{
 | 
			
		||||
    assert(ctx != nullptr);
 | 
			
		||||
    assert(socket > 0);
 | 
			
		||||
 | 
			
		||||
    SSL *ssl = SSL_new(ctx);
 | 
			
		||||
    if (ssl)
 | 
			
		||||
        SSL_set_fd(ssl, socket);
 | 
			
		||||
    return ssl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
namespace ix 
 | 
			
		||||
{
 | 
			
		||||
    std::atomic<bool> SocketOpenSSL::_openSSLInitializationSuccessful(false);
 | 
			
		||||
 | 
			
		||||
    SocketOpenSSL::SocketOpenSSL(int fd) : Socket(fd),
 | 
			
		||||
        _ssl_connection(nullptr), 
 | 
			
		||||
        _ssl_context(nullptr)
 | 
			
		||||
    {
 | 
			
		||||
        ;
 | 
			
		||||
        std::call_once(_openSSLInitFlag, &SocketOpenSSL::openSSLInitialize, this);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SocketOpenSSL::~SocketOpenSSL()
 | 
			
		||||
@@ -86,6 +34,20 @@ namespace ix
 | 
			
		||||
        SocketOpenSSL::close();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SocketOpenSSL::openSSLInitialize()
 | 
			
		||||
    {
 | 
			
		||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
 | 
			
		||||
        if (!OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, nullptr)) return;
 | 
			
		||||
#else
 | 
			
		||||
        (void) OPENSSL_config(nullptr);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        (void) OpenSSL_add_ssl_algorithms();
 | 
			
		||||
        (void) SSL_load_error_strings();
 | 
			
		||||
 | 
			
		||||
        _openSSLInitializationSuccessful = true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::string SocketOpenSSL::getSSLError(int ret)
 | 
			
		||||
    {
 | 
			
		||||
        unsigned long e;
 | 
			
		||||
@@ -153,7 +115,12 @@ namespace ix
 | 
			
		||||
        if (ctx)
 | 
			
		||||
        {
 | 
			
		||||
            // To skip verification, pass in SSL_VERIFY_NONE
 | 
			
		||||
            SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, openssl_verify_callback);
 | 
			
		||||
            SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER,
 | 
			
		||||
                               [](int preverify, X509_STORE_CTX*) -> int
 | 
			
		||||
                               {
 | 
			
		||||
                                   return preverify;
 | 
			
		||||
                               });
 | 
			
		||||
 | 
			
		||||
            SSL_CTX_set_verify_depth(ctx, 4);
 | 
			
		||||
            SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
 | 
			
		||||
        }
 | 
			
		||||
@@ -283,8 +250,9 @@ namespace ix
 | 
			
		||||
        {
 | 
			
		||||
            std::lock_guard<std::mutex> lock(_mutex);
 | 
			
		||||
 | 
			
		||||
            if (!openSSLInitialize(errMsg))
 | 
			
		||||
            if (!_openSSLInitializationSuccessful)
 | 
			
		||||
            {
 | 
			
		||||
                errMsg = "OPENSSL_init_ssl failure";
 | 
			
		||||
                return false;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -306,14 +274,15 @@ namespace ix
 | 
			
		||||
                errMsg += ERR_error_string(ssl_err, nullptr);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            _ssl_connection = openssl_create_connection(_ssl_context, _sockfd);
 | 
			
		||||
            if (nullptr == _ssl_connection)
 | 
			
		||||
            _ssl_connection = SSL_new(_ssl_context);
 | 
			
		||||
            if (_ssl_connection == nullptr)
 | 
			
		||||
            {
 | 
			
		||||
                errMsg = "OpenSSL failed to connect";
 | 
			
		||||
                SSL_CTX_free(_ssl_context);
 | 
			
		||||
                _ssl_context = nullptr;
 | 
			
		||||
                return false;
 | 
			
		||||
            }
 | 
			
		||||
            SSL_set_fd(_ssl_connection, _sockfd);
 | 
			
		||||
 | 
			
		||||
            // SNI support
 | 
			
		||||
            SSL_set_tlsext_host_name(_ssl_connection, host.c_str());
 | 
			
		||||
 
 | 
			
		||||
@@ -36,6 +36,7 @@ namespace ix
 | 
			
		||||
        virtual int recv(void* buffer, size_t length) final;
 | 
			
		||||
 | 
			
		||||
    private:
 | 
			
		||||
        void openSSLInitialize();
 | 
			
		||||
        std::string getSSLError(int ret);
 | 
			
		||||
        SSL_CTX* openSSLCreateContext(std::string& errMsg);
 | 
			
		||||
        bool openSSLHandshake(const std::string& hostname, std::string& errMsg);
 | 
			
		||||
@@ -48,6 +49,9 @@ namespace ix
 | 
			
		||||
        SSL_CTX* _ssl_context;
 | 
			
		||||
        const SSL_METHOD* _ssl_method;
 | 
			
		||||
        mutable std::mutex _mutex;  // OpenSSL routines are not thread-safe
 | 
			
		||||
 | 
			
		||||
        std::once_flag _openSSLInitFlag;
 | 
			
		||||
        static std::atomic<bool> _openSSLInitializationSuccessful;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,7 @@ add_executable(ixwebsocket_unittest
 | 
			
		||||
  test_runner.cpp
 | 
			
		||||
  cmd_websocket_chat.cpp
 | 
			
		||||
  IXWebSocketServerTest.cpp
 | 
			
		||||
  IXSocketTest.cpp
 | 
			
		||||
  IXTest.cpp
 | 
			
		||||
  msgpack11.cpp
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										89
									
								
								test/IXSocketTest.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								test/IXSocketTest.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
/*
 | 
			
		||||
 *  IXSocketTest.cpp
 | 
			
		||||
 *  Author: Benjamin Sergeant
 | 
			
		||||
 *  Copyright (c) 2019 Machine Zone. All rights reserved.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <ixwebsocket/IXSocket.h>
 | 
			
		||||
#include <ixwebsocket/IXCancellationRequest.h>
 | 
			
		||||
 | 
			
		||||
#if defined(__APPLE__) or defined(__linux__)
 | 
			
		||||
# ifdef __APPLE__
 | 
			
		||||
#  include <ixwebsocket/IXSocketAppleSSL.h>
 | 
			
		||||
# else
 | 
			
		||||
#  include <ixwebsocket/IXSocketOpenSSL.h>
 | 
			
		||||
# endif
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include "IXTest.h" 
 | 
			
		||||
#include "catch.hpp"
 | 
			
		||||
 | 
			
		||||
using namespace ix;
 | 
			
		||||
 | 
			
		||||
namespace ix
 | 
			
		||||
{
 | 
			
		||||
    void testSocket(const std::string& host,
 | 
			
		||||
                    int port,
 | 
			
		||||
                    const std::string& request,
 | 
			
		||||
                    std::shared_ptr<Socket> socket,
 | 
			
		||||
                    int expectedStatus,
 | 
			
		||||
                    int timeoutSecs)
 | 
			
		||||
    {
 | 
			
		||||
        std::string errMsg;
 | 
			
		||||
        static std::atomic<bool> requestInitCancellation(false);
 | 
			
		||||
        auto isCancellationRequested =
 | 
			
		||||
            makeCancellationRequestWithTimeout(timeoutSecs, requestInitCancellation);
 | 
			
		||||
 | 
			
		||||
        bool success = socket->connect(host, port, errMsg, isCancellationRequested);
 | 
			
		||||
        REQUIRE(success);
 | 
			
		||||
 | 
			
		||||
        std::cout << "Sending request: " << request
 | 
			
		||||
                  << "to " << host << ":" << port
 | 
			
		||||
                  << std::endl;
 | 
			
		||||
        socket->writeBytes(request, isCancellationRequested);
 | 
			
		||||
 | 
			
		||||
        auto lineResult = socket->readLine(isCancellationRequested);
 | 
			
		||||
        auto lineValid = lineResult.first;
 | 
			
		||||
        auto line = lineResult.second;
 | 
			
		||||
 | 
			
		||||
        REQUIRE(lineValid);
 | 
			
		||||
 | 
			
		||||
        int status = -1;
 | 
			
		||||
        REQUIRE(sscanf(line.c_str(), "HTTP/1.1 %d", &status) == 1);
 | 
			
		||||
        REQUIRE(status == expectedStatus);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("socket", "[socket]")
 | 
			
		||||
{
 | 
			
		||||
    SECTION("Connect to google HTTP server. Send GET request without header. Should return 200")
 | 
			
		||||
    {
 | 
			
		||||
        std::shared_ptr<Socket> socket(new Socket);
 | 
			
		||||
        std::string host("www.google.com");
 | 
			
		||||
        int port = 80;
 | 
			
		||||
        std::string request("GET / HTTP/1.1\r\n\r\n");
 | 
			
		||||
        int expectedStatus = 200;
 | 
			
		||||
        int timeoutSecs = 1;
 | 
			
		||||
 | 
			
		||||
        testSocket(host, port, request, socket, expectedStatus, timeoutSecs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
#if defined(__APPLE__) or defined(__linux__)
 | 
			
		||||
    SECTION("Connect to google HTTPS server. Send GET request without header. Should return 200")
 | 
			
		||||
    {
 | 
			
		||||
# ifdef __APPLE__
 | 
			
		||||
        std::shared_ptr<Socket> socket = std::make_shared<SocketAppleSSL>();
 | 
			
		||||
# else
 | 
			
		||||
        std::shared_ptr<Socket> socket = std::make_shared<SocketOpenSSL>();
 | 
			
		||||
# endif
 | 
			
		||||
        std::string host("www.google.com");
 | 
			
		||||
        int port = 443;
 | 
			
		||||
        std::string request("GET / HTTP/1.1\r\n\r\n");
 | 
			
		||||
        int expectedStatus = 200;
 | 
			
		||||
        int timeoutSecs = 1;
 | 
			
		||||
 | 
			
		||||
        testSocket(host, port, request, socket, expectedStatus, timeoutSecs);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
@@ -42,9 +42,10 @@ def findFiles(prefix):
 | 
			
		||||
 | 
			
		||||
    return paths
 | 
			
		||||
 | 
			
		||||
for path in findFiles('.'):
 | 
			
		||||
    print(path)
 | 
			
		||||
#for path in findFiles('.'):
 | 
			
		||||
#    print(path)
 | 
			
		||||
 | 
			
		||||
# We need to copy the zlib DLL in the current work directory
 | 
			
		||||
shutil.copy(os.path.join(
 | 
			
		||||
    '..',
 | 
			
		||||
    '..',
 | 
			
		||||
@@ -56,7 +57,7 @@ shutil.copy(os.path.join(
 | 
			
		||||
    'bin',
 | 
			
		||||
    'zlib.dll'), '.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# unittest broken on Windows
 | 
			
		||||
if osName != 'Windows':
 | 
			
		||||
    os.system(testBinary)
 | 
			
		||||
    testCommand = '{} {}'.format(testBinary, os.getenv('TEST', ''))
 | 
			
		||||
    os.system(testCommand)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user