diff --git a/Cargo.lock b/Cargo.lock index ebbc4f3..2bf74e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1521,7 +1521,7 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "sign-firmware" -version = "0.1.2" +version = "0.1.3" dependencies = [ "anyhow", "async-io", diff --git a/Cargo.toml b/Cargo.toml index 704d0ac..5745542 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sign-firmware" -version = "0.1.2" +version = "0.1.3" authors = ["Jack Hogan "] edition = "2021" license = "MIT OR Apache-2.0" diff --git a/src/lib.rs b/src/lib.rs index 394ac66..b84c5c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,28 @@ // pub mod eeprom; // pub mod schema; +pub mod net; +use anyhow::anyhow; use esp_idf_svc::{hal::ledc::LedcDriver, sys::EspError}; use std::net::TcpStream; use std::os::fd::{AsRawFd, IntoRawFd}; +#[macro_export] +macro_rules! anyesp { + ($err: expr) => {{ + let res = $err; + if res != ::esp_idf_svc::sys::ESP_OK { + Err(::anyhow::anyhow!("Bad exit code {res}")) + } else { + Ok(()) + } + }}; +} + +pub fn convert_error(e: EspError) -> anyhow::Error { + anyhow!("Bad exit code {e}") +} + pub struct Leds { channels: [LedcDriver<'static>; 15], } diff --git a/src/main.rs b/src/main.rs index 4e7a512..6ddf040 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,334 +1,34 @@ #![feature(type_alias_impl_trait)] -use anyhow::anyhow; -use async_io::Async; use build_time::build_time_utc; use chrono_tz::US::Eastern; -use dotenvy_macro::dotenv; -use embassy_time::{with_timeout, Timer}; +use embassy_time::Timer; use esp_idf_svc::{ eventloop::EspSystemEventLoop, hal::{ gpio::PinDriver, ledc::{config::TimerConfig, LedcDriver, LedcTimerDriver}, peripherals::Peripherals, - reset::restart, task::block_on, }, - io::{self, Write}, + io, nvs::EspDefaultNvsPartition, ota::EspOta, sntp, - sys::EspError, timer::EspTaskTimerService, - tls::EspAsyncTls, - wifi::{AsyncWifi, ClientConfiguration, Configuration, EspWifi}, + wifi::{AsyncWifi, EspWifi}, }; -use http::Request; use lightning_time::LightningTime; use log::info; use palette::rgb::Rgb; -use sign_firmware::{Block, EspTlsSocket, Leds}; -use std::net::TcpStream; -use url::Url; +use sign_firmware::{ + net::{connect_to_network, self_update}, + Block, Leds, +}; extern crate alloc; -use core::str::FromStr; -use std::net::ToSocketAddrs; - -macro_rules! anyesp { - ($err: expr) => {{ - let res = $err; - if res != ::esp_idf_svc::sys::ESP_OK { - Err(::anyhow::anyhow!("Bad exit code {res}")) - } else { - Ok(()) - } - }}; -} - -#[derive(Debug, serde::Deserialize)] -struct GithubResponse { - tag_name: String, - assets: Vec, -} - -#[derive(Debug, serde::Deserialize)] -struct GithubAsset { - browser_download_url: String, -} - -async fn generate_tls(url: &str) -> anyhow::Result> { - let url = Url::from_str(url).unwrap(); - let host = url.host_str().unwrap(); - let addr = format!("{host}:443") - .to_socket_addrs() - .unwrap() - .collect::>(); - - let socket = Async::::connect(addr[0]).await.unwrap(); - - let mut tls = esp_idf_svc::tls::EspAsyncTls::adopt(EspTlsSocket::new(socket)).unwrap(); - - tls.negotiate(host, &esp_idf_svc::tls::Config::new()) - .await - .unwrap(); - - Ok(tls) -} - -fn create_raw_request(request: http::Request) -> String { - let method = request.method(); - let uri = request.uri(); - let headers = request.headers(); - - let mut request_text = format!("{} {} HTTP/1.1\r\n", method, uri); - for (key, value) in headers { - request_text.push_str(&format!("{}: {}\r\n", key, value.to_str().unwrap())); - } - request_text.push_str("\r\n"); // End of headers - - request_text -} - -async fn handle_redirect(url: &str) -> anyhow::Result> { - let request = Request::builder() - .method("GET") - .header("User-Agent", "PHSign/1.0.0") - .header("Host", "github.com") - .uri(url) - .body(()) - .unwrap(); - - let mut tls = generate_tls(url).await?; - - let request_text = create_raw_request(request); - - tls.write_all(request_text.as_bytes()) - .await - .map_err(convert_error)?; - - let mut body = [0; 8192]; - - let _read = io::utils::asynch::try_read_full(&mut tls, &mut body) - .await - .map_err(|(e, _)| e) - .unwrap(); - - let body = String::from_utf8(body.into()).expect("valid UTF8"); - - let splits = body.split("\r\n"); - - for split in splits { - if split.to_lowercase().starts_with("location: ") { - let location = split.split(": ").nth(1).expect("location value"); - - let request = Request::builder() - .method("GET") - .header("User-Agent", "PHSign/1.0.0") - .header("Host", "githubusercontent.com") - .uri(location) - .body(()) - .unwrap(); - - let tls = generate_tls(location).await?; - let request_text = create_raw_request(request); - - tls.write_all(request_text.as_bytes()) - .await - .map_err(convert_error)?; - - return Ok(tls); - } - } - - unreachable!("location must be in returned value!") -} - -async fn self_update(leds: &mut Leds) -> anyhow::Result<()> { - info!("Checking for self-update"); - - let manifest: GithubResponse = { - let url = "https://api.github.com/repos/purduehackers/sign-firmware/releases/latest"; - - let request = Request::builder() - .method("GET") - .header("User-Agent", "PHSign/1.0.0") - .header("Host", "api.github.com") - .uri(url) - .body(()) - .unwrap(); - - let mut tls = generate_tls(url).await?; - - let request_text = create_raw_request(request); - - tls.write_all(request_text.as_bytes()) - .await - .map_err(convert_error)?; - - let mut body = [0; 8192]; - - let _read = io::utils::asynch::try_read_full(&mut tls, &mut body) - .await - .map_err(|(e, _)| e) - .unwrap(); - - let body = String::from_utf8(body.into()).expect("valid UTF8"); - - let ind = body.find("\r\n\r\n").expect("body start"); - - serde_json::from_str(body[ind + 4..].trim().trim_end_matches(char::from(0))) - .expect("Valid parse for GitHub manifest") - }; - - let local = semver::Version::new( - env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(), - env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(), - env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(), - ); - - let remote = semver::Version::from_str(&manifest.tag_name[1..]).expect("valid semver"); - - if remote > local { - info!("New release found! Downloading and updating"); - leds.set_all_colors(Rgb::new(0, 255, 0)); - // Grab new release and update - let url = manifest - .assets - .first() - .expect("release to contain assets") - .browser_download_url - .clone(); - - let tls = handle_redirect(&url).await?; - - // Consume until \r\n\r\n (body) - info!("Consuming headers..."); - { - #[derive(Debug)] - enum ParseConsumerState { - None, - FirstCR, - FirstNL, - SecondCR, - } - let mut state = ParseConsumerState::None; - - let mut consumption_buffer = [0; 1]; - - loop { - let read = tls - .read(&mut consumption_buffer) - .await - .map_err(convert_error) - .expect("read byte for parse consumer"); - - if read == 0 { - panic!("Invalid update parse! Reached EOF before valid body"); - } - state = match state { - ParseConsumerState::None => { - if consumption_buffer[0] == b'\r' { - ParseConsumerState::FirstCR - } else { - ParseConsumerState::None - } - } - ParseConsumerState::FirstCR => { - if consumption_buffer[0] == b'\n' { - ParseConsumerState::FirstNL - } else { - ParseConsumerState::None - } - } - ParseConsumerState::FirstNL => { - if consumption_buffer[0] == b'\r' { - ParseConsumerState::SecondCR - } else { - ParseConsumerState::None - } - } - ParseConsumerState::SecondCR => { - if consumption_buffer[0] == b'\n' { - break; - } else { - ParseConsumerState::None - } - } - } - } - } - - info!("Headers consumed"); - - let mut body = [0; 8192]; - - let mut ota = EspOta::new().expect("ESP OTA success"); - - let mut update = ota.initiate_update().expect("update to initialize"); - - let mut chunk = 0_usize; - loop { - let read = - with_timeout(embassy_time::Duration::from_secs(10), tls.read(&mut body)).await; - - match read { - Ok(Ok(read)) => { - info!("[CHUNK {chunk:>4}] Read {read:>4}"); - - update.write_all(&body[..read]).expect("write update data"); - - if read == 0 { - break; - } - - chunk += 1; - } - Ok(Err(e)) => e.panic(), - Err(_) => break, - }; - } - - info!("Update completed! Activating..."); - - update - .finish() - .expect("update finalization to work") - .activate() - .expect("activation to work"); - - restart(); - } else { - info!("Already on latest version."); - } - - Ok(()) -} - -// #[embassy_executor::task] async fn amain(mut leds: Leds, mut wifi: AsyncWifi>) { - // let tls = TlsConfig::new( - // const_random::const_random!(u64), - // &mut read_buffer, - // &mut write_buffer, - // TlsVerify::None, - // ); - // let mut client = HttpClient::new_with_tls(&tcp, &dns, tls); - - // ThreadSpawnConfiguration { - // name: None, - // stack_size: 60_000, - // priority: 24, - // inherit: false, - // pin_to_core: Some(Core::Core1), - // } - // .set() - // .unwrap(); - - // let mut client = Client::wrap(&mut EspHttpConnection::new(&Default::default()).unwrap()); - // Red before wifi leds.set_all_colors(Rgb::new(255, 0, 0)); @@ -361,8 +61,6 @@ async fn amain(mut leds: Leds, mut wifi: AsyncWifi>) { } } -// static mut APP_CORE_STACK: Stack<8192> = Stack::new(); - fn main() { // It is necessary to call this function once. Otherwise some patches to the runtime // implemented by esp-idf-sys might not link properly. See https://github.com/esp-rs/esp-idf-template/issues/71 @@ -386,23 +84,6 @@ fn main() { let peripherals = Peripherals::take().unwrap(); - // let io = Io::new(peripherals.GPIO, peripherals.IO_MUX); - - // let sw_ints = SoftwareInterruptControl::new(peripherals.SW_INTERRUPT); - - // let timg0 = TimerGroup::new(peripherals.TIMG1); - - // let init = esp_wifi::initialize( - // esp_wifi::EspWifiInitFor::Wifi, - // timg0.timer1, - // esp_hal::rng::Rng::new(peripherals.RNG), - // peripherals.RADIO_CLK, - // ) - // .unwrap(); - - // let (wifi_interface, controller) = - // esp_wifi::wifi::new_with_mode(&init, peripherals.WIFI, WifiStaDevice).unwrap(); - let sys_loop = EspSystemEventLoop::take().unwrap(); let nvs = EspDefaultNvsPartition::take().unwrap(); @@ -529,103 +210,8 @@ fn main() { let _button_led = PinDriver::output(peripherals.pins.gpio15); let _button_switch = PinDriver::input(peripherals.pins.gpio36); - // let leds = [ - // PinDriver::output(peripherals.pins.gpio1.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio2.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio4.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio5.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio6.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio7.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio8.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio9.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio10.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio11.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio12.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio13.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio14.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio17.downgrade_output()).unwrap(), - // PinDriver::output(peripherals.pins.gpio18.downgrade_output()).unwrap(), - // ]; - - // static EXECUTOR_CORE_1: StaticCell> = StaticCell::new(); - // let executor_core1 = InterruptExecutor::new(sw_ints.software_interrupt1); - // let executor_core1 = EXECUTOR_CORE_1.init(executor_core1); - - // let _guard = cpu_control - // .start_app_core(unsafe { &mut *addr_of_mut!(APP_CORE_STACK) }, move || { - // let spawner = executor_core1.start(Priority::max()); - - // spawner.spawn(leds_software_pwm(leds)).ok(); - - // // Just loop to show that the main thread does not need to poll the executor. - // loop {} - // }) - // .unwrap(); - - // unsafe { - // esp_idf_svc::hal::task::create( - // task_handler, - // task_name, - // stack_size, - // task_arg, - // priority, - // pin_to_core, - // ) - // .unwrap(); - // } - - // let (tx, _rx) = channel(); - - // let config = TWDTConfig { - // duration: Duration::from_secs(2), - // panic_on_trigger: true, - // subscribed_idle_tasks: Core::Core0.into(), - // }; - // let mut driver = TWDTDriver::new(peripherals.twdt, &config).unwrap(); - - // ThreadSpawnConfiguration { - // name: None, - // stack_size: 8000, - // priority: 24, - // inherit: false, - // pin_to_core: Some(Core::Core1), - // } - // .set() - // .unwrap(); - // std::thread::spawn(move || { - // let watchdog = driver.watch_current_task().unwrap(); - // // let mut leds = leds; - // // let mut last_buffer = [0; 15]; - // // let timer = EspTimerService::new() - // // .unwrap() - // // .timer(move || { - // // leds_software_pwm_timer(&mut leds, last_buffer); - // // }) - // // .unwrap(); - - // // timer - // // .every(Duration::from_secs_f64(1.0 / (256.0 * 120.0))) - // // .unwrap(); - - // // loop { - // // last_buffer = rx.try_recv().unwrap_or(last_buffer); - // // watchdog.feed().expect("watchdog ok"); - // // } - // // block_on(leds_software_pwm(leds)); - // leds_software_pwm(leds, watchdog, rx); - // }); - let leds = Leds::create(leds); - // static EXECUTOR_CORE_0: StaticCell = StaticCell::new(); - // let executor_core0 = Executor::new(); - // let executor_core0 = EXECUTOR_CORE_0.init(executor_core0); - // executor_core0.run(|spawner| { - // // spawner.spawn(connection(controller)).ok(); - // // spawner.spawn(net_task(stack)).ok(); - // spawner.spawn(amain(leds)).ok(); - // }); - std::thread::Builder::new() .stack_size(60_000) .spawn(|| { @@ -636,102 +222,3 @@ fn main() { .join() .unwrap(); } - -// fn to_anyhow(result: Result) -> anyhow::Result { -// match result { -// Ok(t) => Ok(t), -// Err(e) => Err(convert_error(e)), -// } -// } - -fn convert_error(e: EspError) -> anyhow::Error { - anyhow!("Bad exit code {e}") -} - -async fn connect_to_network(wifi: &mut AsyncWifi>) -> anyhow::Result<()> { - let config = Configuration::Client(ClientConfiguration { - ssid: dotenv!("WIFI_SSID").try_into().unwrap(), - password: "".try_into().unwrap(), - auth_method: esp_idf_svc::wifi::AuthMethod::WPA2Enterprise, - ..Default::default() - }); - - wifi.set_configuration(&config).map_err(convert_error)?; - - unsafe { - use esp_idf_svc::sys::*; - anyesp!(esp_wifi_set_mode(wifi_mode_t_WIFI_MODE_STA))?; - anyesp!(esp_eap_client_set_identity( - dotenv!("WIFI_EMAIL").as_ptr(), - dotenv!("WIFI_EMAIL").len() as i32 - ))?; - anyesp!(esp_eap_client_set_username( - dotenv!("WIFI_USERNAME").as_ptr(), - dotenv!("WIFI_USERNAME").len() as i32 - ))?; - anyesp!(esp_eap_client_set_password( - dotenv!("WIFI_PASSWORD").as_ptr(), - dotenv!("WIFI_PASSWORD").len() as i32 - ))?; - anyesp!(esp_eap_client_set_ttls_phase2_method( - esp_eap_ttls_phase2_types_ESP_EAP_TTLS_PHASE2_MSCHAPV2 - ))?; - anyesp!(esp_wifi_sta_enterprise_enable())?; - } - - wifi.start().await.map_err(convert_error)?; - - // Connect but with a longer timeout - wifi.wifi_mut().connect().map_err(convert_error)?; - wifi.wifi_wait( - |this| this.wifi().is_connected().map(|s| !s), - Some(std::time::Duration::from_secs(60)), - ) - .await?; - - // wifi.connect().await.map_err(convert_error)?; - - wifi.wait_netif_up().await.map_err(convert_error)?; - - info!("Wi-Fi connected!"); - - Ok(()) -} - -// #[embassy_executor::task] -// async fn connection(mut controller: WifiController<'static>) { -// info!("start connection task"); -// debug!("Device capabilities: {:?}", controller.get_capabilities()); -// loop { -// if matches!(esp_wifi::wifi::get_wifi_state(), WifiState::StaConnected) { -// // wait until we're no longer connected -// controller.wait_for_event(WifiEvent::StaDisconnected).await; -// Timer::after_millis(5000).await -// } -// if !matches!(controller.is_started(), Ok(true)) { -// // Assume we don't need any certs -// let client_config = Configuration::EapClient(EapClientConfiguration { -// ssid: heapless::String::from_str(dotenv!("WIFI_SSID")).unwrap(), -// auth_method: esp_wifi::wifi::AuthMethod::WPA2Enterprise, -// identity: Some(heapless::String::from_str(dotenv!("WIFI_USERNAME")).unwrap()), -// username: Some(heapless::String::from_str(dotenv!("WIFI_USERNAME")).unwrap()), -// password: Some(heapless::String::from_str(dotenv!("WIFI_PASSWORD")).unwrap()), -// ttls_phase2_method: Some(TtlsPhase2Method::Mschapv2), -// ..Default::default() -// }); -// controller.set_configuration(&client_config).unwrap(); -// info!("Starting wifi"); -// controller.start().await.unwrap(); -// info!("Wifi started!"); -// } -// info!("About to connect..."); - -// match controller.connect().await { -// Ok(_) => info!("Wifi connected!"), -// Err(e) => { -// error!("Failed to connect to wifi: {e:?}"); -// Timer::after_millis(5000).await -// } -// } -// } -// } diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..023848a --- /dev/null +++ b/src/net.rs @@ -0,0 +1,326 @@ +use core::str::FromStr; +use std::net::TcpStream; +use std::net::ToSocketAddrs; + +use async_io::Async; +use dotenvy_macro::dotenv; +use embassy_time::with_timeout; +use esp_idf_svc::hal::reset::restart; +use esp_idf_svc::io::{self, Write}; +use esp_idf_svc::ota::EspOta; +use esp_idf_svc::tls::EspAsyncTls; +use esp_idf_svc::wifi::{AsyncWifi, ClientConfiguration, Configuration, EspWifi}; +use http::Request; +use log::info; +use palette::rgb::Rgb; +use url::Url; + +use crate::{anyesp, convert_error, EspTlsSocket, Leds}; + +#[derive(Debug, serde::Deserialize)] +struct GithubResponse { + tag_name: String, + assets: Vec, +} + +#[derive(Debug, serde::Deserialize)] +struct GithubAsset { + browser_download_url: String, +} + +pub async fn generate_tls(url: &str) -> anyhow::Result> { + let url = Url::from_str(url).unwrap(); + let host = url.host_str().unwrap(); + let addr = format!("{host}:443") + .to_socket_addrs() + .unwrap() + .collect::>(); + + let socket = Async::::connect(addr[0]).await.unwrap(); + + let mut tls = esp_idf_svc::tls::EspAsyncTls::adopt(EspTlsSocket::new(socket)).unwrap(); + + tls.negotiate(host, &esp_idf_svc::tls::Config::new()) + .await + .unwrap(); + + Ok(tls) +} + +pub fn create_raw_request(request: http::Request) -> String { + let method = request.method(); + let uri = request.uri(); + let headers = request.headers(); + + let mut request_text = format!("{} {} HTTP/1.1\r\n", method, uri); + for (key, value) in headers { + request_text.push_str(&format!("{}: {}\r\n", key, value.to_str().unwrap())); + } + request_text.push_str("\r\n"); // End of headers + + request_text +} + +pub async fn handle_redirect(url: &str) -> anyhow::Result> { + let request = Request::builder() + .method("GET") + .header("User-Agent", "PHSign/1.0.0") + .header("Host", "github.com") + .uri(url) + .body(()) + .unwrap(); + + let mut tls = generate_tls(url).await?; + + let request_text = create_raw_request(request); + + tls.write_all(request_text.as_bytes()) + .await + .map_err(convert_error)?; + + let mut body = [0; 8192]; + + let _read = io::utils::asynch::try_read_full(&mut tls, &mut body) + .await + .map_err(|(e, _)| e) + .unwrap(); + + let body = String::from_utf8(body.into()).expect("valid UTF8"); + + let splits = body.split("\r\n"); + + for split in splits { + if split.to_lowercase().starts_with("location: ") { + let location = split.split(": ").nth(1).expect("location value"); + + let request = Request::builder() + .method("GET") + .header("User-Agent", "PHSign/1.0.0") + .header("Host", "githubusercontent.com") + .uri(location) + .body(()) + .unwrap(); + + let tls = generate_tls(location).await?; + let request_text = create_raw_request(request); + + tls.write_all(request_text.as_bytes()) + .await + .map_err(convert_error)?; + + return Ok(tls); + } + } + + unreachable!("location must be in returned value!") +} + +pub async fn self_update(leds: &mut Leds) -> anyhow::Result<()> { + info!("Checking for self-update"); + + let manifest: GithubResponse = { + let url = "https://api.github.com/repos/purduehackers/sign-firmware/releases/latest"; + + let request = Request::builder() + .method("GET") + .header("User-Agent", "PHSign/1.0.0") + .header("Host", "api.github.com") + .uri(url) + .body(()) + .unwrap(); + + let mut tls = generate_tls(url).await?; + + let request_text = create_raw_request(request); + + tls.write_all(request_text.as_bytes()) + .await + .map_err(convert_error)?; + + let mut body = [0; 8192]; + + let _read = io::utils::asynch::try_read_full(&mut tls, &mut body) + .await + .map_err(|(e, _)| e) + .unwrap(); + + let body = String::from_utf8(body.into()).expect("valid UTF8"); + + let ind = body.find("\r\n\r\n").expect("body start"); + + serde_json::from_str(body[ind + 4..].trim().trim_end_matches(char::from(0))) + .expect("Valid parse for GitHub manifest") + }; + + let local = semver::Version::new( + env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(), + env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(), + env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(), + ); + + let remote = semver::Version::from_str(&manifest.tag_name[1..]).expect("valid semver"); + + if remote > local { + info!("New release found! Downloading and updating"); + leds.set_all_colors(Rgb::new(0, 255, 0)); + // Grab new release and update + let url = manifest + .assets + .first() + .expect("release to contain assets") + .browser_download_url + .clone(); + + let tls = handle_redirect(&url).await?; + + // Consume until \r\n\r\n (body) + info!("Consuming headers..."); + { + #[derive(Debug)] + enum ParseConsumerState { + None, + FirstCR, + FirstNL, + SecondCR, + } + + let mut state = ParseConsumerState::None; + + let mut consumption_buffer = [0; 1]; + + loop { + let read = tls + .read(&mut consumption_buffer) + .await + .map_err(convert_error) + .expect("read byte for parse consumer"); + + if read == 0 { + panic!("Invalid update parse! Reached EOF before valid body"); + } + state = match state { + ParseConsumerState::None => { + if consumption_buffer[0] == b'\r' { + ParseConsumerState::FirstCR + } else { + ParseConsumerState::None + } + } + ParseConsumerState::FirstCR => { + if consumption_buffer[0] == b'\n' { + ParseConsumerState::FirstNL + } else { + ParseConsumerState::None + } + } + ParseConsumerState::FirstNL => { + if consumption_buffer[0] == b'\r' { + ParseConsumerState::SecondCR + } else { + ParseConsumerState::None + } + } + ParseConsumerState::SecondCR => { + if consumption_buffer[0] == b'\n' { + break; + } else { + ParseConsumerState::None + } + } + } + } + } + + info!("Headers consumed"); + + let mut body = [0; 8192]; + + let mut ota = EspOta::new().expect("ESP OTA success"); + + let mut update = ota.initiate_update().expect("update to initialize"); + + let mut chunk = 0_usize; + loop { + let read = + with_timeout(embassy_time::Duration::from_secs(10), tls.read(&mut body)).await; + + match read { + Ok(Ok(read)) => { + info!("[CHUNK {chunk:>4}] Read {read:>4}"); + + update.write_all(&body[..read]).expect("write update data"); + + if read == 0 { + break; + } + + chunk += 1; + } + Ok(Err(e)) => e.panic(), + Err(_) => break, + }; + } + + info!("Update completed! Activating..."); + + update + .finish() + .expect("update finalization to work") + .activate() + .expect("activation to work"); + + restart(); + } else { + info!("Already on latest version."); + } + + Ok(()) +} + +pub async fn connect_to_network(wifi: &mut AsyncWifi>) -> anyhow::Result<()> { + let config = Configuration::Client(ClientConfiguration { + ssid: dotenv!("WIFI_SSID").try_into().unwrap(), + password: "".try_into().unwrap(), + auth_method: esp_idf_svc::wifi::AuthMethod::WPA2Enterprise, + ..Default::default() + }); + + wifi.set_configuration(&config).map_err(convert_error)?; + + unsafe { + use esp_idf_svc::sys::*; + anyesp!(esp_wifi_set_mode(wifi_mode_t_WIFI_MODE_STA))?; + anyesp!(esp_eap_client_set_identity( + dotenv!("WIFI_EMAIL").as_ptr(), + dotenv!("WIFI_EMAIL").len() as i32 + ))?; + anyesp!(esp_eap_client_set_username( + dotenv!("WIFI_USERNAME").as_ptr(), + dotenv!("WIFI_USERNAME").len() as i32 + ))?; + anyesp!(esp_eap_client_set_password( + dotenv!("WIFI_PASSWORD").as_ptr(), + dotenv!("WIFI_PASSWORD").len() as i32 + ))?; + anyesp!(esp_eap_client_set_ttls_phase2_method( + esp_eap_ttls_phase2_types_ESP_EAP_TTLS_PHASE2_MSCHAPV2 + ))?; + anyesp!(esp_wifi_sta_enterprise_enable())?; + } + + wifi.start().await.map_err(convert_error)?; + + // Connect but with a longer timeout + wifi.wifi_mut().connect().map_err(convert_error)?; + wifi.wifi_wait( + |this| this.wifi().is_connected().map(|s| !s), + Some(std::time::Duration::from_secs(60)), + ) + .await?; + + wifi.wait_netif_up().await.map_err(convert_error)?; + + info!("Wi-Fi connected!"); + + Ok(()) +}