(openssl) verify that the certificate we are getting match the domain we are requesting

This commit is contained in:
Benjamin Sergeant 2018-10-05 12:08:45 -07:00
parent 96903b4d25
commit 10ce046b0f
3 changed files with 132 additions and 17 deletions

View File

@ -10,6 +10,8 @@
#include <cassert>
#include <iostream>
#include <openssl/x509v3.h>
#include <errno.h>
#define socketerrno errno
@ -65,18 +67,6 @@ SSL *openssl_create_connection(SSL_CTX *ctx, int socket)
return ssl;
}
bool openssl_check_server_cert(SSL *ssl, std::string& errMsg)
{
X509 *server_cert = SSL_get_peer_certificate(ssl);
if (server_cert == nullptr)
{
errMsg = "OpenSSL failed - peer didn't present a X509 certificate.";
return false;
}
X509_free(server_cert);
return true;
}
} // anonymous namespace
namespace ix
@ -166,7 +156,126 @@ namespace ix
return ctx;
}
bool SocketOpenSSL::openSSLHandshake(std::string& errMsg)
/**
* Check whether a hostname matches a pattern
*
* The pattern MUST contain at most a single, leading asterisk. This means that
* this function cannot serve as a generic validation function, as that would
* allow for partial wildcards, too. Also, this does not check whether the
* wildcard covers multiple levels of labels. For RTM, this suffices, as we
* are only interested in the main domain name.
*
* @param[in] hostname The hostname of the server
* @param[in] pattern The hostname pattern from a SSL certificate
* @return TRUE if the pattern matches, FALSE otherwise
*/
bool SocketOpenSSL::checkHost(const std::string& host, const char *pattern)
{
const char* hostname = host.c_str();
while (*pattern && *hostname)
{
if (*pattern == '*')
{
while (*hostname != '.' && *hostname) hostname++;
if (*(++pattern) != '.')
{
return false;
}
}
else
{
char p = *pattern;
char h = *hostname;
if ((p & ~32) >= 'A' && (p & ~32) <= 'Z')
{
p &= ~32;
h &= ~32;
}
if (*pattern != *hostname)
{
return false;
}
}
pattern++;
hostname++;
}
bool success = !(*hostname || *pattern);
return success;
}
bool SocketOpenSSL::openSSLCheckServerCert(SSL *ssl,
const std::string& hostname,
std::string& errMsg)
{
X509 *server_cert = SSL_get_peer_certificate(ssl);
if (server_cert == nullptr)
{
errMsg = "OpenSSL failed - peer didn't present a X509 certificate.";
return false;
}
#if OPENSSL_VERSION_NUMBER < 0x10100000L
// Check server name
bool hostname_verifies_ok = false;
STACK_OF(GENERAL_NAME) *san_names =
(STACK_OF(GENERAL_NAME)*) X509_get_ext_d2i((X509 *)server_cert,
NID_subject_alt_name, NULL, NULL);
if (san_names)
{
for (int i=0; i<sk_GENERAL_NAME_num(san_names); i++)
{
const GENERAL_NAME *sk_name = sk_GENERAL_NAME_value(san_names, i);
if (sk_name->type == GEN_DNS)
{
char *name = (char *)ASN1_STRING_data(sk_name->d.dNSName);
if ((size_t)ASN1_STRING_length(sk_name->d.dNSName) == strlen(name) &&
checkHost(hostname, name))
{
hostname_verifies_ok = true;
break;
}
}
}
}
sk_GENERAL_NAME_pop_free(san_names, GENERAL_NAME_free);
if (!hostname_verifies_ok)
{
int cn_pos = X509_NAME_get_index_by_NID(X509_get_subject_name((X509 *)server_cert),
NID_commonName, -1);
if (cn_pos)
{
X509_NAME_ENTRY *cn_entry = X509_NAME_get_entry(
X509_get_subject_name((X509 *)server_cert), cn_pos);
if (cn_entry)
{
ASN1_STRING *cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry);
char *cn = (char *)ASN1_STRING_data(cn_asn1);
if ((size_t)ASN1_STRING_length(cn_asn1) == strlen(cn) &&
checkHost(hostname, cn))
{
hostname_verifies_ok = true;
}
}
}
}
if (!hostname_verifies_ok)
{
errMsg = "OpenSSL failed - certificate was issued for a different domain.";
return false;
}
#endif
X509_free(server_cert);
return true;
}
bool SocketOpenSSL::openSSLHandshake(const std::string& host, std::string& errMsg)
{
while (true)
{
@ -179,7 +288,7 @@ namespace ix
int connect_result = SSL_connect(_ssl_connection);
if (connect_result == 1)
{
return openssl_check_server_cert(_ssl_connection, errMsg);
return openSSLCheckServerCert(_ssl_connection, host, errMsg);
}
int reason = SSL_get_error(_ssl_connection, connect_result);
@ -245,7 +354,7 @@ namespace ix
// SNI support
SSL_set_tlsext_host_name(_ssl_connection, host.c_str());
handshakeSuccessful = openSSLHandshake(errMsg);
handshakeSuccessful = openSSLHandshake(host, errMsg);
}
if (!handshakeSuccessful)

View File

@ -36,7 +36,11 @@ namespace ix
private:
std::string getSSLError(int ret);
SSL_CTX* openSSLCreateContext(std::string& errMsg);
bool openSSLHandshake(std::string& errMsg);
bool openSSLHandshake(const std::string& hostname, std::string& errMsg);
bool openSSLCheckServerCert(SSL *ssl,
const std::string& hostname,
std::string& errMsg);
bool checkHost(const std::string& host, const char *pattern);
SSL_CTX* _ssl_context;
SSL* _ssl_connection;

View File

@ -202,7 +202,9 @@ namespace ix {
if (sscanf(line, "HTTP/1.0 %d", &status) == 1)
{
std::stringstream ss;
ss << "Server version is HTTP/1.0. Rejecting connection to " << host;
ss << "Server version is HTTP/1.0. Rejecting connection to " << host
<< ", status: " << status
<< ", HTTP Status line: " << line;
return WebSocketInitResult(false, status, ss.str());
}