From 6f02f0bb57034ffff265cb45d40285bd0bf69f72 Mon Sep 17 00:00:00 2001 From: Lorenzo Felletti <60483783+lorenzofelletti@users.noreply.github.com> Date: Wed, 12 Jun 2024 21:05:25 +0100 Subject: [PATCH] feat(socket): improve TLS socket (#11) * feat: tls options handling more stuff * feat(socket): tls cert too can be passed as a option in open * feat: adds ca and reset_max_fragment_length to `TlsSocketOptions` * chore: bumps minor version * fix: changes the remote visibility * docs: improve docs --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/socket/tcp.rs | 8 ++-- src/socket/tls.rs | 58 +++++++++++++++----------- src/socket/udp.rs | 8 ++-- src/traits/io.rs | 15 +++++-- src/types/mod.rs | 101 +++++++++++++++++++++++++++++++++++++++++----- 7 files changed, 147 insertions(+), 47 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 506690d..ddd0c95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -509,7 +509,7 @@ dependencies = [ [[package]] name = "psp-net" -version = "0.4.0" +version = "0.5.0" dependencies = [ "dns-protocol", "embedded-io", diff --git a/Cargo.toml b/Cargo.toml index 3eaa007..75f6fda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "psp-net" -version = "0.4.0" +version = "0.5.0" edition = "2021" license-file = "LICENSE" keywords = ["psp", "net", "networking", "embedded", "gamedev"] diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 57edc93..cab9322 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -204,18 +204,18 @@ impl ErrorType for TcpSocket { } impl OptionType for TcpSocket { - type Options = SocketOptions; + type Options<'a> = SocketOptions; } -impl Open for TcpSocket { +impl<'a> Open<'a> for TcpSocket { /// Return a TCP socket connected to the remote specified in `options` - fn open(&mut self, options: Self::Options) -> Result<(), Self::Error> + fn open(mut self, options: &'a Self::Options<'a>) -> Result where Self: Sized, { self.connect(options.remote())?; - Ok(()) + Ok(self) } } diff --git a/src/socket/tls.rs b/src/socket/tls.rs index f68ba04..7495b1a 100644 --- a/src/socket/tls.rs +++ b/src/socket/tls.rs @@ -1,8 +1,6 @@ use alloc::string::String; use embedded_io::{ErrorType, Read, Write}; -use embedded_tls::{ - blocking::TlsConnection, Aes128GcmSha256, Certificate, NoVerify, TlsConfig, TlsContext, -}; +use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; @@ -26,6 +24,7 @@ pub struct TlsSocket<'a> { tls_connection: TlsConnection<'a, TcpSocket, Aes128GcmSha256>, /// The TLS config tls_config: TlsConfig<'a, Aes128GcmSha256>, + // certificate: Option>, } impl<'a> TlsSocket<'a> { @@ -36,8 +35,6 @@ impl<'a> TlsSocket<'a> { /// - `socket`: The TCP socket to use for the TLS connection /// - `record_read_buf`: A buffer to use for reading records /// - `record_write_buf`: A buffer to use for writing records - /// - `server_name`: The server name to connect to (e.g. "example.com") - /// - `cert`: An optional certificate to use for the connection /// /// # Returns /// A new TLS socket. @@ -49,7 +46,7 @@ impl<'a> TlsSocket<'a> { /// ```no_run /// let mut read_buf = TlsSocket::new_buffer(); /// let mut write_buf = TlsSocket::new_buffer(); - /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf, "example.com", None); + /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf); /// ``` /// /// # Notes @@ -58,25 +55,15 @@ impl<'a> TlsSocket<'a> { socket: TcpSocket, record_read_buf: &'a mut [u8], record_write_buf: &'a mut [u8], - server_name: &'a str, - cert: Option<&'a [u8]>, ) -> Self { - let tls_config: TlsConfig<'_, Aes128GcmSha256> = match cert { - Some(cert) => TlsConfig::new() - .with_server_name(server_name) - .with_cert(Certificate::RawPublicKey(cert)) - .enable_rsa_signatures(), - None => TlsConfig::new() - .with_server_name(server_name) - .enable_rsa_signatures(), - }; + let tls_config: TlsConfig<'_, Aes128GcmSha256> = TlsConfig::new(); let tls_connection: TlsConnection = TlsConnection::new(socket, record_read_buf, record_write_buf); - TlsSocket { tls_connection, tls_config, + // certificate: None, } } @@ -90,7 +77,7 @@ impl<'a> TlsSocket<'a> { /// ```no_run /// let mut read_buf = TlsSocket::new_buffer(); /// let mut write_buf = TlsSocket::new_buffer(); - /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf, "example.com", None); + /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf); /// ``` #[must_use] pub fn new_buffer() -> [u8; 16_384] { @@ -124,16 +111,41 @@ impl ErrorType for TlsSocket<'_> { } impl OptionType for TlsSocket<'_> { - type Options = TlsSocketOptions; + type Options<'a> = TlsSocketOptions<'a>; } -impl Open for TlsSocket<'_> { +impl<'a, 'b> Open<'a> for TlsSocket<'b> +where + 'a: 'b, +{ /// Open the TLS connection. - fn open(&mut self, options: Self::Options) -> Result<(), embedded_tls::TlsError> { + fn open(mut self, options: &'a Self::Options<'a>) -> Result { let mut rng = ChaCha20Rng::seed_from_u64(options.seed()); + + self.tls_config = self.tls_config.with_server_name(options.server_name()); + + if options.rsa_signatures_enabled() { + self.tls_config = self.tls_config.enable_rsa_signatures(); + } + + if options.reset_max_fragment_length() { + self.tls_config = self.tls_config.reset_max_fragment_length(); + } + + if let Some(cert) = options.cert() { + // self.certificate = Some(Certificate::RawPublicKey(cert)); + self.tls_config = self.tls_config.with_cert(cert.clone()); + } + + if let Some(ca) = options.ca() { + self.tls_config = self.tls_config.with_ca(ca.clone()); + } + let tls_context = TlsContext::new(&self.tls_config, &mut rng); self.tls_connection - .open::(tls_context) + .open::(tls_context)?; + + Ok(self) } } diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 35aa2d5..9df711f 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -360,20 +360,20 @@ impl Drop for UdpSocket { } impl OptionType for UdpSocket { - type Options = SocketOptions; + type Options<'a> = SocketOptions; } impl ErrorType for UdpSocket { type Error = SocketError; } -impl Open for UdpSocket { +impl<'a> Open<'a> for UdpSocket { /// Open the socket - fn open(&mut self, options: Self::Options) -> Result<(), Self::Error> { + fn open(mut self, options: &'a Self::Options<'a>) -> Result { self.bind(None)?; self.connect(options.remote())?; - Ok(()) + Ok(self) } } diff --git a/src/traits/io.rs b/src/traits/io.rs index bf5a710..3f8d85d 100644 --- a/src/traits/io.rs +++ b/src/traits/io.rs @@ -1,14 +1,21 @@ pub trait OptionType { - type Options: ?Sized; + type Options<'b>: ?Sized; } /// Type implementing this trait support a Open semantics. -pub trait Open: ErrorType + OptionType { +pub trait Open<'a>: ErrorType + OptionType { /// Open a resource, using options for configuration. /// + /// # Arguments + /// - `options`: The options to use to configure the TLS connection + /// /// # Errors /// This function can return an error if the resource could not be opened. - fn open(&mut self, options: Self::Options) -> Result<(), Self::Error> + /// + /// # Notes + /// See [`TlsSocketOptions`](crate::types::TlsSocketOptions) for more information + /// on the options you can pass. + fn open(self, options: &'a Self::Options<'a>) -> Result where Self: Sized; } @@ -27,7 +34,7 @@ pub trait Open: ErrorType + OptionType { /// # Notes /// [`EasySocket`] types should implement in their [`drop`] method the steps required /// to close the acquired resources. -pub trait EasySocket: Open + Write + Read {} +pub trait EasySocket: for<'a> Open<'a> + Write + Read {} // re-exports pub trait Write = embedded_io::Write; diff --git a/src/types/mod.rs b/src/types/mod.rs index 06a925b..2461d66 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,5 @@ +use alloc::string::String; + use crate::socket::SocketAddr; /// Socket options, such as remote address to connect to. @@ -11,7 +13,7 @@ use crate::socket::SocketAddr; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct SocketOptions { /// Remote address to connect to - pub remote: SocketAddr, + remote: SocketAddr, } impl SocketOptions { @@ -28,20 +30,51 @@ impl SocketOptions { } } -#[derive(Clone, Debug, PartialEq, Eq)] -/// TLS socket options +/// TLS socket options. /// -/// # Fields -/// - [`seed`](Self::seed): Seed for the RNG -pub struct TlsSocketOptions { - pub seed: u64, +/// This is used by [`TlsSocket`](super::socket::tls::TlsSocket) when used as a +/// [`EasySocket`](super::traits::io::EasySocket). +#[derive(Clone, Debug)] +pub struct TlsSocketOptions<'a> { + seed: u64, + server_name: String, + cert: Option>, + ca: Option>, + enable_rsa_signatures: bool, + reset_max_fragment_length: bool, } -impl TlsSocketOptions { +impl<'a> TlsSocketOptions<'a> { /// Create a new socket options + /// + /// # Arguments + /// - `seed`: The seed to use for the RNG + /// - `server_name`: The server name to use #[must_use] - pub fn new(seed: u64) -> Self { - Self { seed } + pub fn new(seed: u64, server_name: String) -> Self { + Self { + seed, + server_name, + cert: None, + ca: None, + enable_rsa_signatures: true, + reset_max_fragment_length: false, + } + } + + /// Disable RSA signatures + /// + /// By default, RSA signatures are enabled. + pub fn disable_rsa_signatures(&mut self) { + self.enable_rsa_signatures = false; + } + + /// Set the certificate + /// + /// # Arguments + /// - `cert`: The certificate + pub fn set_cert(&mut self, cert: Certificate<'a>) { + self.cert = Some(cert); } /// Get the seed @@ -49,4 +82,52 @@ impl TlsSocketOptions { pub fn seed(&self) -> u64 { self.seed } + + /// Get the server name + #[must_use] + pub fn server_name(&self) -> &str { + &self.server_name + } + + /// Get the certificate + #[must_use] + pub fn cert(&self) -> Option<&Certificate<'a>> { + self.cert.as_ref() + } + + /// Return whether RSA signatures are enabled + #[must_use] + pub fn rsa_signatures_enabled(&self) -> bool { + self.enable_rsa_signatures + } + + /// Return whether the max fragment length should be reset + #[must_use] + pub fn reset_max_fragment_length(&self) -> bool { + self.reset_max_fragment_length + } + + /// Set whether the max fragment length should be reset + /// + /// By default, the max fragment length is not reset. + pub fn set_reset_max_fragment_length(&mut self, reset_max_fragment_length: bool) { + self.reset_max_fragment_length = reset_max_fragment_length; + } + + /// Get the CA + #[must_use] + pub fn ca(&self) -> Option<&Certificate<'a>> { + self.ca.as_ref() + } + + /// Set the CA (certificate authority) + /// + /// # Arguments + /// - `ca`: The CA + pub fn set_ca(&mut self, ca: Option>) { + self.ca = ca; + } } + +// re-exports +pub type Certificate<'a> = embedded_tls::Certificate<'a>;