Skip to content

Commit

Permalink
feat(socket): improve TLS socket (#11)
Browse files Browse the repository at this point in the history
* feat: tls options handling more stuff

* feat(socket): tls cert too can be passed as a option in open

* feat: adds ca and reset_max_fragment_length to `TlsSocketOptions`

* chore: bumps minor version

* fix: changes the remote visibility

* docs: improve docs
  • Loading branch information
lorenzofelletti authored Jun 12, 2024
1 parent e7a682c commit 6f02f0b
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "psp-net"
version = "0.4.0"
version = "0.5.0"
edition = "2021"
license-file = "LICENSE"
keywords = ["psp", "net", "networking", "embedded", "gamedev"]
Expand Down
8 changes: 4 additions & 4 deletions src/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,18 @@ impl ErrorType for TcpSocket {
}

impl OptionType for TcpSocket {
type Options = SocketOptions;
type Options<'a> = SocketOptions;
}

impl Open for TcpSocket {
impl<'a> Open<'a> for TcpSocket {
/// Return a TCP socket connected to the remote specified in `options`
fn open(&mut self, options: Self::Options) -> Result<(), Self::Error>
fn open(mut self, options: &'a Self::Options<'a>) -> Result<Self, Self::Error>
where
Self: Sized,
{
self.connect(options.remote())?;

Ok(())
Ok(self)
}
}

Expand Down
58 changes: 35 additions & 23 deletions src/socket/tls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use alloc::string::String;
use embedded_io::{ErrorType, Read, Write};
use embedded_tls::{
blocking::TlsConnection, Aes128GcmSha256, Certificate, NoVerify, TlsConfig, TlsContext,
};
use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext};

use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
Expand All @@ -26,6 +24,7 @@ pub struct TlsSocket<'a> {
tls_connection: TlsConnection<'a, TcpSocket, Aes128GcmSha256>,
/// The TLS config
tls_config: TlsConfig<'a, Aes128GcmSha256>,
// certificate: Option<Certificate<'a>>,
}

impl<'a> TlsSocket<'a> {
Expand All @@ -36,8 +35,6 @@ impl<'a> TlsSocket<'a> {
/// - `socket`: The TCP socket to use for the TLS connection
/// - `record_read_buf`: A buffer to use for reading records
/// - `record_write_buf`: A buffer to use for writing records
/// - `server_name`: The server name to connect to (e.g. "example.com")
/// - `cert`: An optional certificate to use for the connection
///
/// # Returns
/// A new TLS socket.
Expand All @@ -49,7 +46,7 @@ impl<'a> TlsSocket<'a> {
/// ```no_run
/// let mut read_buf = TlsSocket::new_buffer();
/// let mut write_buf = TlsSocket::new_buffer();
/// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf, "example.com", None);
/// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
/// ```
///
/// # Notes
Expand All @@ -58,25 +55,15 @@ impl<'a> TlsSocket<'a> {
socket: TcpSocket,
record_read_buf: &'a mut [u8],
record_write_buf: &'a mut [u8],
server_name: &'a str,
cert: Option<&'a [u8]>,
) -> Self {
let tls_config: TlsConfig<'_, Aes128GcmSha256> = match cert {
Some(cert) => TlsConfig::new()
.with_server_name(server_name)
.with_cert(Certificate::RawPublicKey(cert))
.enable_rsa_signatures(),
None => TlsConfig::new()
.with_server_name(server_name)
.enable_rsa_signatures(),
};
let tls_config: TlsConfig<'_, Aes128GcmSha256> = TlsConfig::new();

let tls_connection: TlsConnection<TcpSocket, Aes128GcmSha256> =
TlsConnection::new(socket, record_read_buf, record_write_buf);

TlsSocket {
tls_connection,
tls_config,
// certificate: None,
}
}

Expand All @@ -90,7 +77,7 @@ impl<'a> TlsSocket<'a> {
/// ```no_run
/// let mut read_buf = TlsSocket::new_buffer();
/// let mut write_buf = TlsSocket::new_buffer();
/// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf, "example.com", None);
/// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
/// ```
#[must_use]
pub fn new_buffer() -> [u8; 16_384] {
Expand Down Expand Up @@ -124,16 +111,41 @@ impl ErrorType for TlsSocket<'_> {
}

impl OptionType for TlsSocket<'_> {
type Options = TlsSocketOptions;
type Options<'a> = TlsSocketOptions<'a>;
}

impl Open for TlsSocket<'_> {
impl<'a, 'b> Open<'a> for TlsSocket<'b>
where
'a: 'b,
{
/// Open the TLS connection.
fn open(&mut self, options: Self::Options) -> Result<(), embedded_tls::TlsError> {
fn open(mut self, options: &'a Self::Options<'a>) -> Result<Self, embedded_tls::TlsError> {
let mut rng = ChaCha20Rng::seed_from_u64(options.seed());

self.tls_config = self.tls_config.with_server_name(options.server_name());

if options.rsa_signatures_enabled() {
self.tls_config = self.tls_config.enable_rsa_signatures();
}

if options.reset_max_fragment_length() {
self.tls_config = self.tls_config.reset_max_fragment_length();
}

if let Some(cert) = options.cert() {
// self.certificate = Some(Certificate::RawPublicKey(cert));
self.tls_config = self.tls_config.with_cert(cert.clone());
}

if let Some(ca) = options.ca() {
self.tls_config = self.tls_config.with_ca(ca.clone());
}

let tls_context = TlsContext::new(&self.tls_config, &mut rng);
self.tls_connection
.open::<ChaCha20Rng, NoVerify>(tls_context)
.open::<ChaCha20Rng, NoVerify>(tls_context)?;

Ok(self)
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,20 @@ impl Drop for UdpSocket {
}

impl OptionType for UdpSocket {
type Options = SocketOptions;
type Options<'a> = SocketOptions;
}

impl ErrorType for UdpSocket {
type Error = SocketError;
}

impl Open for UdpSocket {
impl<'a> Open<'a> for UdpSocket {
/// Open the socket
fn open(&mut self, options: Self::Options) -> Result<(), Self::Error> {
fn open(mut self, options: &'a Self::Options<'a>) -> Result<Self, Self::Error> {
self.bind(None)?;
self.connect(options.remote())?;

Ok(())
Ok(self)
}
}

Expand Down
15 changes: 11 additions & 4 deletions src/traits/io.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
pub trait OptionType {
type Options: ?Sized;
type Options<'b>: ?Sized;
}

/// Type implementing this trait support a Open semantics.
pub trait Open: ErrorType + OptionType {
pub trait Open<'a>: ErrorType + OptionType {
/// Open a resource, using options for configuration.
///
/// # Arguments
/// - `options`: The options to use to configure the TLS connection
///
/// # Errors
/// This function can return an error if the resource could not be opened.
fn open(&mut self, options: Self::Options) -> Result<(), Self::Error>
///
/// # Notes
/// See [`TlsSocketOptions`](crate::types::TlsSocketOptions) for more information
/// on the options you can pass.
fn open(self, options: &'a Self::Options<'a>) -> Result<Self, Self::Error>
where
Self: Sized;
}
Expand All @@ -27,7 +34,7 @@ pub trait Open: ErrorType + OptionType {
/// # Notes
/// [`EasySocket`] types should implement in their [`drop`] method the steps required
/// to close the acquired resources.
pub trait EasySocket: Open + Write + Read {}
pub trait EasySocket: for<'a> Open<'a> + Write + Read {}

// re-exports
pub trait Write = embedded_io::Write;
Expand Down
101 changes: 91 additions & 10 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use alloc::string::String;

use crate::socket::SocketAddr;

/// Socket options, such as remote address to connect to.
Expand All @@ -11,7 +13,7 @@ use crate::socket::SocketAddr;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SocketOptions {
/// Remote address to connect to
pub remote: SocketAddr,
remote: SocketAddr,
}

impl SocketOptions {
Expand All @@ -28,25 +30,104 @@ impl SocketOptions {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
/// TLS socket options
/// TLS socket options.
///
/// # Fields
/// - [`seed`](Self::seed): Seed for the RNG
pub struct TlsSocketOptions {
pub seed: u64,
/// This is used by [`TlsSocket`](super::socket::tls::TlsSocket) when used as a
/// [`EasySocket`](super::traits::io::EasySocket).
#[derive(Clone, Debug)]
pub struct TlsSocketOptions<'a> {
seed: u64,
server_name: String,
cert: Option<Certificate<'a>>,
ca: Option<Certificate<'a>>,
enable_rsa_signatures: bool,
reset_max_fragment_length: bool,
}

impl TlsSocketOptions {
impl<'a> TlsSocketOptions<'a> {
/// Create a new socket options
///
/// # Arguments
/// - `seed`: The seed to use for the RNG
/// - `server_name`: The server name to use
#[must_use]
pub fn new(seed: u64) -> Self {
Self { seed }
pub fn new(seed: u64, server_name: String) -> Self {
Self {
seed,
server_name,
cert: None,
ca: None,
enable_rsa_signatures: true,
reset_max_fragment_length: false,
}
}

/// Disable RSA signatures
///
/// By default, RSA signatures are enabled.
pub fn disable_rsa_signatures(&mut self) {
self.enable_rsa_signatures = false;
}

/// Set the certificate
///
/// # Arguments
/// - `cert`: The certificate
pub fn set_cert(&mut self, cert: Certificate<'a>) {
self.cert = Some(cert);
}

/// Get the seed
#[must_use]
pub fn seed(&self) -> u64 {
self.seed
}

/// Get the server name
#[must_use]
pub fn server_name(&self) -> &str {
&self.server_name
}

/// Get the certificate
#[must_use]
pub fn cert(&self) -> Option<&Certificate<'a>> {
self.cert.as_ref()
}

/// Return whether RSA signatures are enabled
#[must_use]
pub fn rsa_signatures_enabled(&self) -> bool {
self.enable_rsa_signatures
}

/// Return whether the max fragment length should be reset
#[must_use]
pub fn reset_max_fragment_length(&self) -> bool {
self.reset_max_fragment_length
}

/// Set whether the max fragment length should be reset
///
/// By default, the max fragment length is not reset.
pub fn set_reset_max_fragment_length(&mut self, reset_max_fragment_length: bool) {
self.reset_max_fragment_length = reset_max_fragment_length;
}

/// Get the CA
#[must_use]
pub fn ca(&self) -> Option<&Certificate<'a>> {
self.ca.as_ref()
}

/// Set the CA (certificate authority)
///
/// # Arguments
/// - `ca`: The CA
pub fn set_ca(&mut self, ca: Option<Certificate<'a>>) {
self.ca = ca;
}
}

// re-exports
pub type Certificate<'a> = embedded_tls::Certificate<'a>;

0 comments on commit 6f02f0b

Please sign in to comment.