#ifndef LWIP_OPEN_SRC #define LWIP_OPEN_SRC #endif #include #include #include "ArduinoOTA.h" #include "ESPmDNS.h" #include "MD5Builder.h" #include "Update.h" // #define OTA_DEBUG Serial ArduinoOTAClass::ArduinoOTAClass() : _port(0) , _initialized(false) , _rebootOnSuccess(true) , _mdnsEnabled(true) , _state(OTA_IDLE) , _size(0) , _cmd(0) , _ota_port(0) , _ota_timeout(1000) , _start_callback(NULL) , _end_callback(NULL) , _error_callback(NULL) , _progress_callback(NULL) { } ArduinoOTAClass::~ArduinoOTAClass(){ _udp_ota.stop(); } ArduinoOTAClass& ArduinoOTAClass::onStart(THandlerFunction fn) { _start_callback = fn; return *this; } ArduinoOTAClass& ArduinoOTAClass::onEnd(THandlerFunction fn) { _end_callback = fn; return *this; } ArduinoOTAClass& ArduinoOTAClass::onProgress(THandlerFunction_Progress fn) { _progress_callback = fn; return *this; } ArduinoOTAClass& ArduinoOTAClass::onError(THandlerFunction_Error fn) { _error_callback = fn; return *this; } ArduinoOTAClass& ArduinoOTAClass::setPort(uint16_t port) { if (!_initialized && !_port && port) { _port = port; } return *this; } ArduinoOTAClass& ArduinoOTAClass::setHostname(const char * hostname) { if (!_initialized && !_hostname.length() && hostname) { _hostname = hostname; } return *this; } String ArduinoOTAClass::getHostname() { return _hostname; } ArduinoOTAClass& ArduinoOTAClass::setPassword(const char * password) { if (!_initialized && !_password.length() && password) { MD5Builder passmd5; passmd5.begin(); passmd5.add(password); passmd5.calculate(); _password = passmd5.toString(); } return *this; } ArduinoOTAClass& ArduinoOTAClass::setPasswordHash(const char * password) { if (!_initialized && !_password.length() && password) { _password = password; } return *this; } ArduinoOTAClass& ArduinoOTAClass::setRebootOnSuccess(bool reboot){ _rebootOnSuccess = reboot; return *this; } ArduinoOTAClass& ArduinoOTAClass::setMdnsEnabled(bool enabled){ _mdnsEnabled = enabled; return *this; } void ArduinoOTAClass::begin() { if (_initialized){ log_w("already initialized"); return; } if (!_port) { _port = 3232; } if(!_udp_ota.begin(_port)){ log_e("udp bind failed"); return; } if (!_hostname.length()) { char tmp[20]; uint8_t mac[6]; WiFi.macAddress(mac); sprintf(tmp, "esp32-%02x%02x%02x%02x%02x%02x", mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]); _hostname = tmp; } if(_mdnsEnabled){ MDNS.begin(_hostname.c_str()); MDNS.enableArduino(_port, (_password.length() > 0)); } _initialized = true; _state = OTA_IDLE; #ifdef OTA_DEBUG OTA_DEBUG.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port); #endif } int ArduinoOTAClass::parseInt(){ char data[INT_BUFFER_SIZE]; uint8_t index = 0; char value; while(_udp_ota.peek() == ' ') _udp_ota.read(); while(index < INT_BUFFER_SIZE - 1){ value = _udp_ota.peek(); if(value < '0' || value > '9'){ data[index++] = '\0'; return atoi(data); } data[index++] = _udp_ota.read(); } return 0; } String ArduinoOTAClass::readStringUntil(char end){ String res = ""; int value; while(true){ value = _udp_ota.read(); if(value <= 0 || value == end){ return res; } res += (char)value; } return res; } void ArduinoOTAClass::_onRx(){ if (_state == OTA_IDLE) { int cmd = parseInt(); if (cmd != U_FLASH && cmd != U_SPIFFS) return; _cmd = cmd; _ota_port = parseInt(); _size = parseInt(); _udp_ota.read(); _md5 = readStringUntil('\n'); _md5.trim(); if(_md5.length() != 32){ return; } if (_password.length()){ MD5Builder nonce_md5; nonce_md5.begin(); nonce_md5.add(String(micros())); nonce_md5.calculate(); _nonce = nonce_md5.toString(); _udp_ota.beginPacket(_udp_ota.remoteIP(), _udp_ota.remotePort()); _udp_ota.printf("AUTH %s", _nonce.c_str()); _udp_ota.endPacket(); _state = OTA_WAITAUTH; return; } else { _udp_ota.beginPacket(_udp_ota.remoteIP(), _udp_ota.remotePort()); _udp_ota.print("OK"); _udp_ota.endPacket(); _ota_ip = _udp_ota.remoteIP(); _state = OTA_RUNUPDATE; } } else if (_state == OTA_WAITAUTH) { int cmd = parseInt(); if (cmd != U_AUTH) { _state = OTA_IDLE; return; } _udp_ota.read(); String cnonce = readStringUntil(' '); String response = readStringUntil('\n'); if (cnonce.length() != 32 || response.length() != 32) { _state = OTA_IDLE; return; } String challenge = _password + ":" + String(_nonce) + ":" + cnonce; MD5Builder _challengemd5; _challengemd5.begin(); _challengemd5.add(challenge); _challengemd5.calculate(); String result = _challengemd5.toString(); if(result.equals(response)){ _udp_ota.beginPacket(_udp_ota.remoteIP(), _udp_ota.remotePort()); _udp_ota.print("OK"); _udp_ota.endPacket(); _ota_ip = _udp_ota.remoteIP(); _state = OTA_RUNUPDATE; } else { _udp_ota.beginPacket(_udp_ota.remoteIP(), _udp_ota.remotePort()); _udp_ota.print("Authentication Failed"); _udp_ota.endPacket(); if (_error_callback) _error_callback(OTA_AUTH_ERROR); _state = OTA_IDLE; } } } void ArduinoOTAClass::_runUpdate() { if (!Update.begin(_size, _cmd)) { #ifdef OTA_DEBUG Update.printError(OTA_DEBUG); #endif if (_error_callback) { _error_callback(OTA_BEGIN_ERROR); } _state = OTA_IDLE; return; } Update.setMD5(_md5.c_str()); if (_start_callback) { _start_callback(); } if (_progress_callback) { _progress_callback(0, _size); } WiFiClient client; if (!client.connect(_ota_ip, _ota_port)) { if (_error_callback) { _error_callback(OTA_CONNECT_ERROR); } _state = OTA_IDLE; } uint32_t written = 0, total = 0, tried = 0; while (!Update.isFinished() && client.connected()) { size_t waited = _ota_timeout; size_t available = client.available(); while (!available && waited){ delay(1); waited -=1 ; available = client.available(); } if (!waited){ if(written && tried++ < 3){ #ifdef OTA_DEBUG OTA_DEBUG.printf("Try[%u]: %u\n", tried, written); #endif if(!client.printf("%u", written)){ #ifdef OTA_DEBUG OTA_DEBUG.printf("failed to respond\n"); #endif _state = OTA_IDLE; break; } continue; } #ifdef OTA_DEBUG OTA_DEBUG.printf("Receive Failed\n"); #endif if (_error_callback) { _error_callback(OTA_RECEIVE_ERROR); } _state = OTA_IDLE; Update.abort(); return; } if(!available){ #ifdef OTA_DEBUG OTA_DEBUG.printf("No Data: %u\n", waited); #endif _state = OTA_IDLE; break; } tried = 0; static uint8_t buf[1460]; if(available > 1460){ available = 1460; } size_t r = client.read(buf, available); if(r != available){ log_w("didn't read enough! %u != %u", r, available); } written = Update.write(buf, r); if (written > 0) { if(written != r){ log_w("didn't write enough! %u != %u", written, r); } if(!client.printf("%u", written)){ #ifdef OTA_DEBUG OTA_DEBUG.printf("failed to respond\n"); #endif } total += written; if(_progress_callback) { _progress_callback(total, _size); } } else { #ifdef OTA_DEBUG Update.printError(OTA_DEBUG); #endif } } if (Update.end()) { client.print("OK"); client.stop(); delay(10); if (_end_callback) { _end_callback(); } if(_rebootOnSuccess){ //let serial/network finish tasks that might be given in _end_callback delay(100); ESP.restart(); } } else { if (_error_callback) { _error_callback(OTA_END_ERROR); } Update.printError(client); client.stop(); delay(10); #ifdef OTA_DEBUG OTA_DEBUG.print("Update ERROR: "); Update.printError(OTA_DEBUG); #endif _state = OTA_IDLE; } } void ArduinoOTAClass::end() { _initialized = false; _udp_ota.stop(); if(_mdnsEnabled){ MDNS.end(); } _state = OTA_IDLE; #ifdef OTA_DEBUG OTA_DEBUG.println("OTA server stopped."); #endif } void ArduinoOTAClass::handle() { if (!_initialized) { return; } if (_state == OTA_RUNUPDATE) { _runUpdate(); _state = OTA_IDLE; } if(_udp_ota.parsePacket()){ _onRx(); } _udp_ota.flush(); // always flush, even zero length packets must be flushed. } int ArduinoOTAClass::getCommand() { return _cmd; } void ArduinoOTAClass::setTimeout(int timeoutInMillis) { _ota_timeout = timeoutInMillis; } #if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_ARDUINOOTA) ArduinoOTAClass ArduinoOTA; #endif