Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a builder pattern for enabling hardware accelerated RSA #31

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions esp-mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ pub use esp_mbedtls_sys::bindings::{
use esp_mbedtls_sys::c_types::*;

/// Hold the RSA peripheral for cryptographic operations.
///
/// This is initialized when `with_hardware_rsa()` is called on a [Session] and is set back to None
/// when the session that called `with_hardware_rsa()` is dropped.
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the session called with this function is dropped.
static mut RSA_REF: Option<Rsa<esp_hal::Blocking>> = None;

// these will come from esp-wifi (i.e. this can only be used together with esp-wifi)
Expand Down Expand Up @@ -385,6 +391,8 @@ pub struct Session<T> {
crt: *mut mbedtls_x509_crt,
client_crt: *mut mbedtls_x509_crt,
private_key: *mut mbedtls_pk_context,
// Indicate if this session is the one holding the RSA ref
owns_rsa: bool,
}

impl<T> Session<T> {
Expand All @@ -399,8 +407,6 @@ impl<T> Session<T> {
/// * `min_version` - The minimum TLS version for the connection, that will be accepted.
/// * `certificates` - Certificate chain for the connection. Will play a different role
/// depending on if running as client or server. See [Certificates] for more information.
/// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto
/// accelerators for the session. Passing None will use the software implementation of RSA which is slower.
///
/// # Errors
///
Expand All @@ -413,20 +419,33 @@ impl<T> Session<T> {
mode: Mode,
min_version: TlsVersion,
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
ssl_context,
ssl_config,
crt,
client_crt,
private_key,
owns_rsa: false,
});
}

/// Enable the use of the hardware accelerated RSA peripheral for the [Session].
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the sesssion called with this function is dropped.
///
/// # Arguments
///
/// * `rsa` - The RSA peripheral from the HAL
pub fn with_hardware_rsa(mut self, rsa: impl Peripheral<P = RSA>) -> Self {
unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) }
self.owns_rsa = true;
self
}
}

impl<T> Session<T>
Expand Down Expand Up @@ -536,6 +555,11 @@ impl<T> Drop for Session<T> {
fn drop(&mut self) {
log::debug!("session dropped - freeing memory");
unsafe {
// If the struct that owns the RSA reference is dropped
// we remove RSA in static for safety
if self.owns_rsa {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
Expand Down Expand Up @@ -611,6 +635,7 @@ pub mod asynch {
eof: bool,
tx_buffer: BufferedBytes<BUFFER_SIZE>,
rx_buffer: BufferedBytes<BUFFER_SIZE>,
owns_rsa: bool,
}

impl<T, const BUFFER_SIZE: usize> Session<T, BUFFER_SIZE> {
Expand All @@ -625,8 +650,6 @@ pub mod asynch {
/// * `min_version` - The minimum TLS version for the connection, that will be accepted.
/// * `certificates` - Certificate chain for the connection. Will play a different role
/// depending on if running as client or server. See [Certificates] for more information.
/// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto
/// accelerators for the session. Passing None will use the software implementation of RSA which is slower.
///
/// # Errors
///
Expand All @@ -639,11 +662,9 @@ pub mod asynch {
mode: Mode,
min_version: TlsVersion,
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
ssl_context,
Expand All @@ -654,14 +675,34 @@ pub mod asynch {
eof: false,
tx_buffer: Default::default(),
rx_buffer: Default::default(),
owns_rsa: false,
});
}

/// Enable the use of the hardware accelerated RSA peripheral for the [Session].
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the sesssion called with this function is dropped.
///
/// # Arguments
///
/// * `rsa` - The RSA peripheral from the HAL
pub fn with_hardware_rsa(mut self, rsa: impl Peripheral<P = RSA>) -> Self {
unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) }
self.owns_rsa = true;
self
}
}

impl<T, const BUFFER_SIZE: usize> Drop for Session<T, BUFFER_SIZE> {
fn drop(&mut self) {
log::debug!("session dropped - freeing memory");
unsafe {
// If the struct that owns the RSA reference is dropped
// we remove RSA in static for safety
if self.owns_rsa {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
Expand Down
4 changes: 2 additions & 2 deletions examples/async_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/async_client_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ async fn main(spawner: Spawner) -> ! {
Mode::Client,
TlsVersion::Tls1_3,
certificates,
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/async_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

println!("Start tls connect");
match tls.connect().await {
Expand Down
4 changes: 2 additions & 2 deletions examples/async_server_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

println!("Start tls connect");
match tls.connect().await {
Expand Down
4 changes: 2 additions & 2 deletions examples/sync_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/sync_client_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ fn main() -> ! {
Mode::Client,
TlsVersion::Tls1_3,
certificates,
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().unwrap();
Expand Down
5 changes: 3 additions & 2 deletions examples/sync_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

match tls.connect() {
Ok(mut connected_session) => {
loop {
Expand Down
5 changes: 3 additions & 2 deletions examples/sync_server_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

match tls.connect() {
Ok(mut connected_session) => {
loop {
Expand Down
Loading