(openssl) verify that the certificate we are getting match the domain we are requesting
This commit is contained in:
		@@ -10,6 +10,8 @@
 | 
			
		||||
#include <cassert>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
 | 
			
		||||
#include <openssl/x509v3.h>
 | 
			
		||||
 | 
			
		||||
#include <errno.h>
 | 
			
		||||
#define socketerrno errno
 | 
			
		||||
 | 
			
		||||
@@ -65,18 +67,6 @@ SSL *openssl_create_connection(SSL_CTX *ctx, int socket)
 | 
			
		||||
    return ssl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool openssl_check_server_cert(SSL *ssl, std::string& errMsg)
 | 
			
		||||
{
 | 
			
		||||
    X509 *server_cert = SSL_get_peer_certificate(ssl);
 | 
			
		||||
    if (server_cert == nullptr)
 | 
			
		||||
    {
 | 
			
		||||
        errMsg = "OpenSSL failed - peer didn't present a X509 certificate.";
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    X509_free(server_cert);
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
namespace ix 
 | 
			
		||||
@@ -166,7 +156,126 @@ namespace ix
 | 
			
		||||
        return ctx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool SocketOpenSSL::openSSLHandshake(std::string& errMsg) 
 | 
			
		||||
    /**
 | 
			
		||||
     * Check whether a hostname matches a pattern
 | 
			
		||||
     *
 | 
			
		||||
     * The pattern MUST contain at most a single, leading asterisk. This means that
 | 
			
		||||
     * this function cannot serve as a generic validation function, as that would
 | 
			
		||||
     * allow for partial wildcards, too. Also, this does not check whether the
 | 
			
		||||
     * wildcard covers multiple levels of labels. For RTM, this suffices, as we
 | 
			
		||||
     * are only interested in the main domain name.
 | 
			
		||||
     *
 | 
			
		||||
     * @param[in] hostname The hostname of the server
 | 
			
		||||
     * @param[in] pattern The hostname pattern from a SSL certificate
 | 
			
		||||
     * @return TRUE if the pattern matches, FALSE otherwise
 | 
			
		||||
     */
 | 
			
		||||
    bool SocketOpenSSL::checkHost(const std::string& host, const char *pattern)
 | 
			
		||||
    {
 | 
			
		||||
        const char* hostname = host.c_str();
 | 
			
		||||
 | 
			
		||||
        while (*pattern && *hostname)
 | 
			
		||||
        {
 | 
			
		||||
            if (*pattern == '*')
 | 
			
		||||
            {
 | 
			
		||||
                while (*hostname != '.' && *hostname) hostname++;
 | 
			
		||||
                if (*(++pattern) != '.')
 | 
			
		||||
                {
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            else
 | 
			
		||||
            {
 | 
			
		||||
                char p = *pattern;
 | 
			
		||||
                char h = *hostname;
 | 
			
		||||
                if ((p & ~32) >= 'A' && (p & ~32) <= 'Z')
 | 
			
		||||
                {
 | 
			
		||||
                    p &= ~32;
 | 
			
		||||
                    h &= ~32;
 | 
			
		||||
                }
 | 
			
		||||
                if (*pattern != *hostname)
 | 
			
		||||
                {
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            pattern++;
 | 
			
		||||
            hostname++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        bool success = !(*hostname || *pattern);
 | 
			
		||||
        return success;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool SocketOpenSSL::openSSLCheckServerCert(SSL *ssl,
 | 
			
		||||
                                               const std::string& hostname,
 | 
			
		||||
                                               std::string& errMsg)
 | 
			
		||||
    {
 | 
			
		||||
        X509 *server_cert = SSL_get_peer_certificate(ssl);
 | 
			
		||||
        if (server_cert == nullptr)
 | 
			
		||||
        {
 | 
			
		||||
            errMsg = "OpenSSL failed - peer didn't present a X509 certificate.";
 | 
			
		||||
            return false;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
#if OPENSSL_VERSION_NUMBER < 0x10100000L
 | 
			
		||||
        // Check server name
 | 
			
		||||
        bool hostname_verifies_ok = false;
 | 
			
		||||
        STACK_OF(GENERAL_NAME) *san_names = 
 | 
			
		||||
            (STACK_OF(GENERAL_NAME)*) X509_get_ext_d2i((X509 *)server_cert,
 | 
			
		||||
                                                       NID_subject_alt_name, NULL, NULL);
 | 
			
		||||
        if (san_names)
 | 
			
		||||
        {
 | 
			
		||||
            for (int i=0; i<sk_GENERAL_NAME_num(san_names); i++)
 | 
			
		||||
            {
 | 
			
		||||
                const GENERAL_NAME *sk_name = sk_GENERAL_NAME_value(san_names, i);
 | 
			
		||||
                if (sk_name->type == GEN_DNS)
 | 
			
		||||
                {
 | 
			
		||||
                    char *name = (char *)ASN1_STRING_data(sk_name->d.dNSName);
 | 
			
		||||
                    if ((size_t)ASN1_STRING_length(sk_name->d.dNSName) == strlen(name) && 
 | 
			
		||||
                        checkHost(hostname, name)) 
 | 
			
		||||
                    {
 | 
			
		||||
                        hostname_verifies_ok = true;
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        sk_GENERAL_NAME_pop_free(san_names, GENERAL_NAME_free);
 | 
			
		||||
 | 
			
		||||
        if (!hostname_verifies_ok)
 | 
			
		||||
        {
 | 
			
		||||
            int cn_pos = X509_NAME_get_index_by_NID(X509_get_subject_name((X509 *)server_cert),
 | 
			
		||||
                                                    NID_commonName, -1);
 | 
			
		||||
            if (cn_pos)
 | 
			
		||||
            {
 | 
			
		||||
                X509_NAME_ENTRY *cn_entry = X509_NAME_get_entry(
 | 
			
		||||
                    X509_get_subject_name((X509 *)server_cert), cn_pos);
 | 
			
		||||
 | 
			
		||||
                if (cn_entry)
 | 
			
		||||
                {
 | 
			
		||||
                    ASN1_STRING *cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry);
 | 
			
		||||
                    char *cn = (char *)ASN1_STRING_data(cn_asn1);
 | 
			
		||||
 | 
			
		||||
                    if ((size_t)ASN1_STRING_length(cn_asn1) == strlen(cn) && 
 | 
			
		||||
                       checkHost(hostname, cn)) 
 | 
			
		||||
                    {
 | 
			
		||||
                        hostname_verifies_ok = true;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (!hostname_verifies_ok)
 | 
			
		||||
        {
 | 
			
		||||
            errMsg = "OpenSSL failed - certificate was issued for a different domain.";
 | 
			
		||||
            return false;
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        X509_free(server_cert);
 | 
			
		||||
        return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool SocketOpenSSL::openSSLHandshake(const std::string& host, std::string& errMsg) 
 | 
			
		||||
    {
 | 
			
		||||
        while (true)
 | 
			
		||||
        {
 | 
			
		||||
@@ -179,7 +288,7 @@ namespace ix
 | 
			
		||||
            int connect_result = SSL_connect(_ssl_connection);
 | 
			
		||||
            if (connect_result == 1)
 | 
			
		||||
            {
 | 
			
		||||
                return openssl_check_server_cert(_ssl_connection, errMsg);
 | 
			
		||||
                return openSSLCheckServerCert(_ssl_connection, host, errMsg);
 | 
			
		||||
            }
 | 
			
		||||
            int reason = SSL_get_error(_ssl_connection, connect_result);
 | 
			
		||||
 | 
			
		||||
@@ -245,7 +354,7 @@ namespace ix
 | 
			
		||||
            // SNI support
 | 
			
		||||
            SSL_set_tlsext_host_name(_ssl_connection, host.c_str());
 | 
			
		||||
 | 
			
		||||
            handshakeSuccessful = openSSLHandshake(errMsg);
 | 
			
		||||
            handshakeSuccessful = openSSLHandshake(host, errMsg);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (!handshakeSuccessful)
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,11 @@ namespace ix
 | 
			
		||||
    private:
 | 
			
		||||
        std::string getSSLError(int ret);
 | 
			
		||||
        SSL_CTX* openSSLCreateContext(std::string& errMsg);
 | 
			
		||||
        bool openSSLHandshake(std::string& errMsg);
 | 
			
		||||
        bool openSSLHandshake(const std::string& hostname, std::string& errMsg);
 | 
			
		||||
        bool openSSLCheckServerCert(SSL *ssl,
 | 
			
		||||
                                    const std::string& hostname,
 | 
			
		||||
                                    std::string& errMsg);
 | 
			
		||||
        bool checkHost(const std::string& host, const char *pattern);
 | 
			
		||||
 | 
			
		||||
        SSL_CTX* _ssl_context;
 | 
			
		||||
        SSL* _ssl_connection;
 | 
			
		||||
 
 | 
			
		||||
@@ -202,7 +202,9 @@ namespace ix {
 | 
			
		||||
        if (sscanf(line, "HTTP/1.0 %d", &status) == 1)
 | 
			
		||||
        {
 | 
			
		||||
            std::stringstream ss;
 | 
			
		||||
            ss << "Server version is HTTP/1.0. Rejecting connection to " << host;
 | 
			
		||||
            ss << "Server version is HTTP/1.0. Rejecting connection to " << host
 | 
			
		||||
               << ", status: " << status
 | 
			
		||||
               << ", HTTP Status line: " << line;
 | 
			
		||||
            return WebSocketInitResult(false, status, ss.str());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user