diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp index 5793ccf1..e3ad4be9 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp @@ -210,6 +210,14 @@ void WiFiClientSecure::setPrivateKey (const char *private_key) _private_key = private_key; } +bool WiFiClientSecure::verify(const char* fp, const char* domain_name) +{ + if (!sslclient) + return false; + + return verify_ssl_fingerprint(sslclient, fp, domain_name); +} + int WiFiClientSecure::lastError(char *buf, const size_t size) { if (!_lastError) { diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.h b/libraries/WiFiClientSecure/src/WiFiClientSecure.h index d57669e7..60f24e79 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.h +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.h @@ -58,6 +58,7 @@ public: void setCACert(const char *rootCA); void setCertificate(const char *client_ca); void setPrivateKey (const char *private_key); + bool verify(const char* fingerprint, const char* domain_name); operator bool() { diff --git a/libraries/WiFiClientSecure/src/ssl_client.cpp b/libraries/WiFiClientSecure/src/ssl_client.cpp index d409c2ef..cab8e02a 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.cpp +++ b/libraries/WiFiClientSecure/src/ssl_client.cpp @@ -12,6 +12,10 @@ #include #include #include +#include +#include +#include +#include #include "ssl_client.h" @@ -262,3 +266,145 @@ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length) //log_v( "%d bytes read", ret); //for low level debug return ret; } + +static bool parseHexNibble(char pb, uint8_t* res) +{ + if (pb >= '0' && pb <= '9') { + *res = (uint8_t) (pb - '0'); return true; + } else if (pb >= 'a' && pb <= 'f') { + *res = (uint8_t) (pb - 'a' + 10); return true; + } else if (pb >= 'A' && pb <= 'F') { + *res = (uint8_t) (pb - 'A' + 10); return true; + } + return false; +} + +// Compare a name from certificate and domain name, return true if they match +static bool matchName(const std::string& name, const std::string& domainName) +{ + size_t wildcardPos = name.find('*'); + if (wildcardPos == std::string::npos) { + // Not a wildcard, expect an exact match + return name == domainName; + } + + size_t firstDotPos = name.find('.'); + if (wildcardPos > firstDotPos) { + // Wildcard is not part of leftmost component of domain name + // Do not attempt to match (rfc6125 6.4.3.1) + return false; + } + if (wildcardPos != 0 || firstDotPos != 1) { + // Matching of wildcards such as baz*.example.com and b*z.example.com + // is optional. Maybe implement this in the future? + return false; + } + size_t domainNameFirstDotPos = domainName.find('.'); + if (domainNameFirstDotPos == std::string::npos) { + return false; + } + return domainName.substr(domainNameFirstDotPos) == name.substr(firstDotPos); +} + +// Verifies certificate provided by the peer to match specified SHA256 fingerprint +bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name) +{ + // Convert hex string to byte array + uint8_t fingerprint_local[32]; + int len = strlen(fp); + int pos = 0; + for (size_t i = 0; i < sizeof(fingerprint_local); ++i) { + while (pos < len && ((fp[pos] == ' ') || (fp[pos] == ':'))) { + ++pos; + } + if (pos > len - 2) { + log_d("pos:%d len:%d fingerprint too short", pos, len); + return false; + } + uint8_t high, low; + if (!parseHexNibble(fp[pos], &high) || !parseHexNibble(fp[pos+1], &low)) { + log_d("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]); + return false; + } + pos += 2; + fingerprint_local[i] = low | (high << 4); + } + + // Get certificate provided by the peer + const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); + + if (!crt) + { + log_d("could not fetch peer certificate"); + return false; + } + + // Calculate certificate's SHA256 fingerprint + uint8_t fingerprint_remote[32]; + mbedtls_sha256_context sha256_ctx; + mbedtls_sha256_init(&sha256_ctx); + mbedtls_sha256_starts(&sha256_ctx, false); + mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len); + mbedtls_sha256_finish(&sha256_ctx, fingerprint_remote); + + // Check if fingerprints match + if (memcmp(fingerprint_local, fingerprint_remote, 32)) + { + log_d("fingerprint doesn't match"); + return false; + } + + // Additionally check if certificate has domain name if provided + if (domain_name) + return verify_ssl_dn(ssl_client, domain_name); + else + return true; +} + +// Checks if peer certificate has specified domain in CN or SANs +bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name) +{ + log_d("domain name: '%s'", (domain_name)?domain_name:"(null)"); + std::string domain_name_str(domain_name); + std::transform(domain_name_str.begin(), domain_name_str.end(), domain_name_str.begin(), ::tolower); + + // Get certificate provided by the peer + const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); + + // Check for domain name in SANs + const mbedtls_x509_sequence* san = &crt->subject_alt_names; + while (san != nullptr) + { + std::string san_str((const char*)san->buf.p, san->buf.len); + std::transform(san_str.begin(), san_str.end(), san_str.begin(), ::tolower); + + if (matchName(san_str, domain_name_str)) + return true; + + log_d("SAN '%s': no match", san_str.c_str()); + + // Fetch next SAN + san = san->next; + } + + // Check for domain name in CN + const mbedtls_asn1_named_data* common_name = &crt->subject; + while (common_name != nullptr) + { + // While iterating through DN objects, check for CN object + if (!MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &common_name->oid)) + { + std::string common_name_str((const char*)common_name->val.p, common_name->val.len); + + if (matchName(common_name_str, domain_name_str)) + return true; + + log_d("CN '%s': not match", common_name_str.c_str()); + } + + // Fetch next DN object + common_name = common_name->next; + } + + return false; +} diff --git a/libraries/WiFiClientSecure/src/ssl_client.h b/libraries/WiFiClientSecure/src/ssl_client.h index 96903971..81e0b33a 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.h +++ b/libraries/WiFiClientSecure/src/ssl_client.h @@ -32,5 +32,7 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len); int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length); +bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name); +bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name); #endif