Port SSL fingerprint checking from ESP8266 WiFiClientSecure to ESP32 (#1397)

This commit is contained in:
chemicstry 2018-05-14 14:00:40 +03:00 committed by Me No Dev
parent 0ea9ea4447
commit 00f962439a
4 changed files with 157 additions and 0 deletions

View File

@ -210,6 +210,14 @@ void WiFiClientSecure::setPrivateKey (const char *private_key)
_private_key = private_key; _private_key = private_key;
} }
bool WiFiClientSecure::verify(const char* fp, const char* domain_name)
{
if (!sslclient)
return false;
return verify_ssl_fingerprint(sslclient, fp, domain_name);
}
int WiFiClientSecure::lastError(char *buf, const size_t size) int WiFiClientSecure::lastError(char *buf, const size_t size)
{ {
if (!_lastError) { if (!_lastError) {

View File

@ -58,6 +58,7 @@ public:
void setCACert(const char *rootCA); void setCACert(const char *rootCA);
void setCertificate(const char *client_ca); void setCertificate(const char *client_ca);
void setPrivateKey (const char *private_key); void setPrivateKey (const char *private_key);
bool verify(const char* fingerprint, const char* domain_name);
operator bool() operator bool()
{ {

View File

@ -12,6 +12,10 @@
#include <lwip/sockets.h> #include <lwip/sockets.h>
#include <lwip/sys.h> #include <lwip/sys.h>
#include <lwip/netdb.h> #include <lwip/netdb.h>
#include <mbedtls/sha256.h>
#include <mbedtls/oid.h>
#include <algorithm>
#include <string>
#include "ssl_client.h" #include "ssl_client.h"
@ -262,3 +266,145 @@ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length)
//log_v( "%d bytes read", ret); //for low level debug //log_v( "%d bytes read", ret); //for low level debug
return ret; return ret;
} }
static bool parseHexNibble(char pb, uint8_t* res)
{
if (pb >= '0' && pb <= '9') {
*res = (uint8_t) (pb - '0'); return true;
} else if (pb >= 'a' && pb <= 'f') {
*res = (uint8_t) (pb - 'a' + 10); return true;
} else if (pb >= 'A' && pb <= 'F') {
*res = (uint8_t) (pb - 'A' + 10); return true;
}
return false;
}
// Compare a name from certificate and domain name, return true if they match
static bool matchName(const std::string& name, const std::string& domainName)
{
size_t wildcardPos = name.find('*');
if (wildcardPos == std::string::npos) {
// Not a wildcard, expect an exact match
return name == domainName;
}
size_t firstDotPos = name.find('.');
if (wildcardPos > firstDotPos) {
// Wildcard is not part of leftmost component of domain name
// Do not attempt to match (rfc6125 6.4.3.1)
return false;
}
if (wildcardPos != 0 || firstDotPos != 1) {
// Matching of wildcards such as baz*.example.com and b*z.example.com
// is optional. Maybe implement this in the future?
return false;
}
size_t domainNameFirstDotPos = domainName.find('.');
if (domainNameFirstDotPos == std::string::npos) {
return false;
}
return domainName.substr(domainNameFirstDotPos) == name.substr(firstDotPos);
}
// Verifies certificate provided by the peer to match specified SHA256 fingerprint
bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name)
{
// Convert hex string to byte array
uint8_t fingerprint_local[32];
int len = strlen(fp);
int pos = 0;
for (size_t i = 0; i < sizeof(fingerprint_local); ++i) {
while (pos < len && ((fp[pos] == ' ') || (fp[pos] == ':'))) {
++pos;
}
if (pos > len - 2) {
log_d("pos:%d len:%d fingerprint too short", pos, len);
return false;
}
uint8_t high, low;
if (!parseHexNibble(fp[pos], &high) || !parseHexNibble(fp[pos+1], &low)) {
log_d("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]);
return false;
}
pos += 2;
fingerprint_local[i] = low | (high << 4);
}
// Get certificate provided by the peer
const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx);
if (!crt)
{
log_d("could not fetch peer certificate");
return false;
}
// Calculate certificate's SHA256 fingerprint
uint8_t fingerprint_remote[32];
mbedtls_sha256_context sha256_ctx;
mbedtls_sha256_init(&sha256_ctx);
mbedtls_sha256_starts(&sha256_ctx, false);
mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len);
mbedtls_sha256_finish(&sha256_ctx, fingerprint_remote);
// Check if fingerprints match
if (memcmp(fingerprint_local, fingerprint_remote, 32))
{
log_d("fingerprint doesn't match");
return false;
}
// Additionally check if certificate has domain name if provided
if (domain_name)
return verify_ssl_dn(ssl_client, domain_name);
else
return true;
}
// Checks if peer certificate has specified domain in CN or SANs
bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name)
{
log_d("domain name: '%s'", (domain_name)?domain_name:"(null)");
std::string domain_name_str(domain_name);
std::transform(domain_name_str.begin(), domain_name_str.end(), domain_name_str.begin(), ::tolower);
// Get certificate provided by the peer
const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx);
// Check for domain name in SANs
const mbedtls_x509_sequence* san = &crt->subject_alt_names;
while (san != nullptr)
{
std::string san_str((const char*)san->buf.p, san->buf.len);
std::transform(san_str.begin(), san_str.end(), san_str.begin(), ::tolower);
if (matchName(san_str, domain_name_str))
return true;
log_d("SAN '%s': no match", san_str.c_str());
// Fetch next SAN
san = san->next;
}
// Check for domain name in CN
const mbedtls_asn1_named_data* common_name = &crt->subject;
while (common_name != nullptr)
{
// While iterating through DN objects, check for CN object
if (!MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &common_name->oid))
{
std::string common_name_str((const char*)common_name->val.p, common_name->val.len);
if (matchName(common_name_str, domain_name_str))
return true;
log_d("CN '%s': not match", common_name_str.c_str());
}
// Fetch next DN object
common_name = common_name->next;
}
return false;
}

View File

@ -32,5 +32,7 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons
int data_to_read(sslclient_context *ssl_client); int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len);
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length); int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);
bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name);
bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name);
#endif #endif