From a13df1070eeadee061da49cdb82ce5ea3c355852 Mon Sep 17 00:00:00 2001 From: Lorenzo Felletti <60483783+lorenzofelletti@users.noreply.github.com> Date: Wed, 17 Jul 2024 20:09:12 +0100 Subject: [PATCH] feat: use typestate pattern for UdpSocket and TcpSocket (#17) * wip * wip * feat: implement typestate pattern for UdpSocket and TcpSocket * feat: implement socket file descriptor abstraction That removes the need for a custom drop for upd and tcp sockets * chore: add .clippy.toml configuration * chore: clippy configuration * style: implement clippy suggestions * refactor: remove need for custom drop impl --- .clippy.toml | 3 + src/dns.rs | 32 +-- src/lib.rs | 3 + src/socket/error.rs | 9 - src/socket/mod.rs | 4 + src/socket/sce.rs | 52 +++++ src/socket/state.rs | 27 +++ src/socket/tcp.rs | 164 +++++++------- src/socket/tls.rs | 11 +- src/socket/udp.rs | 411 +++++++++++++++++------------------- src/traits/dns.rs | 2 + src/traits/io.rs | 6 +- src/traits/mod.rs | 3 +- src/types/socket_flags.rs | 4 +- src/types/socket_options.rs | 2 + 15 files changed, 393 insertions(+), 340 deletions(-) create mode 100644 .clippy.toml create mode 100644 src/socket/sce.rs create mode 100644 src/socket/state.rs diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..d5cbb94 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1,3 @@ +doc-valid-idents = ["PlayStation", ".."] + +allowed-prefixes = ["To", ".."] \ No newline at end of file diff --git a/src/dns.rs b/src/dns.rs index 6b4b655..13a7f70 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,3 +1,5 @@ +#![allow(clippy::module_name_repetitions)] + use alloc::{ borrow::ToOwned, string::{String, ToString}, @@ -8,7 +10,7 @@ use embedded_io::{Read, Write}; use embedded_nal::{IpAddr, Ipv4Addr, SocketAddr}; use psp::sys::in_addr; -use crate::socket::udp::UdpSocketState; +use crate::socket::state::Connected; use super::{ socket::{udp::UdpSocket, ToSocketAddr}, @@ -31,7 +33,7 @@ pub fn create_a_type_query(domain: &str) -> Question { #[derive(Debug, Clone, PartialEq, Eq)] pub enum DnsError { /// The DNS resolver failed to create - FailedToCreate, + FailedToCreate(String), /// The hostname could not be resolved HostnameResolutionFailed(String), /// The IP address could not be resolved @@ -41,7 +43,7 @@ pub enum DnsError { /// A DNS resolver pub struct DnsResolver { /// The UDP socket that is used to send and receive DNS messages - udp_socket: UdpSocket, + udp_socket: UdpSocket, /// The DNS server address dns: SocketAddr, } @@ -57,10 +59,14 @@ impl DnsResolver { /// happen if the socket could not be created or bound to the specified address #[allow(unused)] pub fn new(dns: SocketAddr) -> Result { - let mut udp_socket = UdpSocket::new().map_err(|_| DnsError::FailedToCreate)?; - udp_socket + let udp_socket = UdpSocket::new() + .map_err(|_| DnsError::FailedToCreate("Failed to create socket".to_owned()))?; + let udp_socket = udp_socket .bind(None) // binds to None, otherwise the socket errors for some reason - .map_err(|_| DnsError::FailedToCreate)?; + .map_err(|_| DnsError::FailedToCreate("Failed to bind socket".to_owned()))?; + let udp_socket = udp_socket + .connect(dns) + .map_err(|_| DnsError::FailedToCreate("Failed to connect socket".to_owned()))?; Ok(DnsResolver { udp_socket, dns }) } @@ -89,13 +95,6 @@ impl DnsResolver { /// This may happen if the connection of the socket fails, or if the DNS server /// does not answer the query, or any other error occurs pub fn resolve(&mut self, host: &str) -> Result { - // connect to the DNS server, if not already - if self.udp_socket.get_state() != UdpSocketState::Connected { - self.udp_socket - .connect(self.dns) - .map_err(|e| DnsError::HostnameResolutionFailed(e.to_string()))?; - } - // create a new query let mut questions = [super::dns::create_a_type_query(host)]; let query = dns_protocol::Message::new( @@ -166,6 +165,13 @@ impl DnsResolver { )), } } + + /// Get the [`SocketAddr`] of the DNS server + #[must_use] + #[inline] + pub fn dns(&self) -> SocketAddr { + self.dns + } } impl traits::dns::ResolveHostname for DnsResolver { diff --git a/src/lib.rs b/src/lib.rs index 7d06abb..8da3582 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,9 @@ #![no_std] #![feature(trait_alias)] #![doc = include_str!("../README.md")] +#![allow(clippy::cast_sign_loss)] +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_possible_wrap)] extern crate alloc; diff --git a/src/socket/error.rs b/src/socket/error.rs index d05f8cc..d649907 100644 --- a/src/socket/error.rs +++ b/src/socket/error.rs @@ -3,14 +3,6 @@ use core::fmt::Display; /// An error that can occur with a socket #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum SocketError { - /// The socket is not connected - NotConnected, - /// The socket is already connected - AlreadyConnected, - /// The socket is already bound - AlreadyBound, - /// The socket is not bound - NotBound, /// Unsupported address family UnsupportedAddressFamily, /// Socket error with errno @@ -29,7 +21,6 @@ impl Display for SocketError { impl embedded_io::Error for SocketError { fn kind(&self) -> embedded_io::ErrorKind { match self { - SocketError::NotConnected => embedded_io::ErrorKind::NotConnected, SocketError::UnsupportedAddressFamily => embedded_io::ErrorKind::Unsupported, _ => embedded_io::ErrorKind::Other, } diff --git a/src/socket/mod.rs b/src/socket/mod.rs index e5e3f19..a283472 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -1,9 +1,13 @@ +#![allow(clippy::module_name_repetitions)] + use embedded_nal::{Ipv4Addr, SocketAddrV4}; use psp::sys::{in_addr, sockaddr}; use super::netc; pub mod error; +pub mod sce; +pub mod state; pub mod tcp; pub mod tls; pub mod udp; diff --git a/src/socket/sce.rs b/src/socket/sce.rs new file mode 100644 index 0000000..867af8b --- /dev/null +++ b/src/socket/sce.rs @@ -0,0 +1,52 @@ +use core::ops::Deref; + +use alloc::rc::Rc; +use psp::sys; + +/// Raw socket file descriptor +/// +/// This is a wrapper around a raw socket file descriptor, which +/// takes care of closing it when no other references to it exist. +/// +/// # Notes +/// The drop implementation of this type calls the close syscall. +/// Closing via drop is best-effort as of now (errors are ignored). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct RawSocketFileDescriptor(pub(crate) i32); + +impl Drop for RawSocketFileDescriptor { + fn drop(&mut self) { + unsafe { + sys::sceNetInetClose(self.0); + }; + } +} + +impl Deref for RawSocketFileDescriptor { + type Target = i32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// Socket file descriptor +/// +/// This is a wrapper around a raw socket file descriptor, which +/// takes care of closing it when no other references to it exist. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SocketFileDescriptor(pub(crate) Rc); + +impl SocketFileDescriptor { + pub(crate) fn new(fd: i32) -> Self { + Self(Rc::new(RawSocketFileDescriptor(fd))) + } +} + +impl Deref for SocketFileDescriptor { + type Target = i32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/socket/state.rs b/src/socket/state.rs new file mode 100644 index 0000000..ab91237 --- /dev/null +++ b/src/socket/state.rs @@ -0,0 +1,27 @@ +use core::fmt::Debug; + +/// Trait describing the state of a socket +pub trait SocketState: Debug {} + +/// Socket is in an unbound state +#[derive(Debug)] +pub struct Unbound; +impl SocketState for Unbound {} + +/// Socket is in a bound state +#[derive(Debug)] +pub struct Bound; +impl SocketState for Bound {} + +/// Socket is in a connected state +#[derive(Debug)] +pub struct Connected; +impl SocketState for Connected {} + +#[derive(Debug)] +pub struct NotReady; +impl SocketState for NotReady {} + +#[derive(Debug)] +pub struct Ready; +impl SocketState for Ready {} diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 4ef423a..4a0a2b5 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -1,4 +1,5 @@ -use alloc::boxed::Box; +#![allow(clippy::module_name_repetitions)] + use alloc::vec::Vec; use embedded_io::{ErrorType, Read, Write}; @@ -14,6 +15,8 @@ use crate::types::{SocketOptions, SocketRecvFlags, SocketSendFlags}; use super::super::netc; use super::error::SocketError; +use super::sce::SocketFileDescriptor; +use super::state::{Connected, SocketState, Unbound}; use super::ToSockaddr; /// A TCP socket @@ -33,25 +36,26 @@ use super::ToSockaddr; /// ```no_run /// use psp::net::TcpSocket; /// -/// let mut socket = TcpSocket::new().unwrap(); +/// let socket = TcpSocket::new().unwrap(); /// let socket_options = SocketOptions{ remote: addr }; -/// socket.open(socket_options).unwrap(); +/// let socket = socket.open(socket_options).unwrap(); /// socket.write(b"hello world").unwrap(); /// socket.flush().unwrap(); /// // no need to call close, as drop will do it /// ``` #[repr(C)] -pub struct TcpSocket { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TcpSocket> { /// The socket file descriptor - fd: i32, - /// Whether the socket is connected - is_connected: bool, + pub(super) fd: SocketFileDescriptor, /// The buffer to store data to send - buffer: Box, + buffer: B, /// flags for send calls send_flags: SocketSendFlags, /// flags for recv calls recv_flags: SocketRecvFlags, + /// marker for the socket state + _marker: core::marker::PhantomData, } impl TcpSocket { @@ -63,20 +67,64 @@ impl TcpSocket { /// # Errors /// - [`SocketError::Errno`] if the socket could not be created #[allow(dead_code)] - pub fn new() -> Result { + pub fn new() -> Result, SocketError> { let fd = unsafe { sys::sceNetInetSocket(i32::from(netc::AF_INET), netc::SOCK_STREAM, 0) }; if fd < 0 { Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) } else { + let fd = SocketFileDescriptor::new(fd); Ok(TcpSocket { fd, - is_connected: false, - buffer: Box::>::default(), + buffer: Vec::with_capacity(0), send_flags: SocketSendFlags::empty(), recv_flags: SocketRecvFlags::empty(), + _marker: core::marker::PhantomData, }) } } +} + +impl TcpSocket { + /// Return the underlying socket's file descriptor + #[must_use] + pub fn fd(&self) -> i32 { + *self.fd + } + + /// Flags used when sending data + #[must_use] + pub fn send_flags(&self) -> SocketSendFlags { + self.send_flags + } + + /// Set the flags used when sending data + pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) { + self.send_flags = send_flags; + } + + /// Flags used when receiving data + #[must_use] + pub fn recv_flags(&self) -> SocketRecvFlags { + self.recv_flags + } + + /// Set the flags used when receiving data + pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) { + self.recv_flags = recv_flags; + } +} + +impl TcpSocket { + #[must_use] + fn transition(self) -> TcpSocket { + TcpSocket { + fd: self.fd, + buffer: Vec::default(), + send_flags: self.send_flags, + recv_flags: self.recv_flags, + _marker: core::marker::PhantomData, + } + } /// Connect to a remote host /// @@ -90,19 +138,14 @@ impl TcpSocket { /// # Errors /// - [`SocketError::UnsupportedAddressFamily`] if the address family is not supported (only IPv4 is supported) /// - Any other [`SocketError`] if the connection was unsuccessful - #[allow(dead_code)] - #[allow(dead_code)] - pub fn connect(&mut self, remote: SocketAddr) -> Result<(), SocketError> { - if self.is_connected { - return Err(SocketError::AlreadyConnected); - } + pub fn connect(self, remote: SocketAddr) -> Result, SocketError> { match remote { SocketAddr::V4(v4) => { let sockaddr = v4.to_sockaddr(); if unsafe { sys::sceNetInetConnect( - self.fd, + *self.fd, &sockaddr, core::mem::size_of::() as u32, ) @@ -111,14 +154,15 @@ impl TcpSocket { let errno = unsafe { sys::sceNetInetGetErrno() }; Err(SocketError::Errno(errno)) } else { - self.is_connected = true; - Ok(()) + Ok(self.transition()) } } SocketAddr::V6(_) => Err(SocketError::UnsupportedAddressFamily), } } +} +impl TcpSocket { /// Read from the socket /// /// # Returns @@ -135,7 +179,7 @@ impl TcpSocket { pub fn _read(&self, buf: &mut [u8]) -> Result { let result = unsafe { sys::sceNetInetRecv( - self.fd, + *self.fd, buf.as_mut_ptr().cast::(), buf.len(), self.recv_flags.as_i32(), @@ -153,19 +197,11 @@ impl TcpSocket { /// # Errors /// - A [`SocketError`] if the write was unsuccessful pub fn _write(&mut self, buf: &[u8]) -> Result { - if !self.is_connected { - return Err(SocketError::NotConnected); - } - self.buffer.append_buffer(buf); self.send() } fn _flush(&mut self) -> Result<(), SocketError> { - if !self.is_connected { - return Err(SocketError::NotConnected); - } - while !self.buffer.is_empty() { self.send()?; } @@ -175,7 +211,7 @@ impl TcpSocket { fn send(&mut self) -> Result { let result = unsafe { sys::sceNetInetSend( - self.fd, + *self.fd, self.buffer.as_slice().as_ptr().cast::(), self.buffer.len(), self.send_flags.as_i32(), @@ -188,71 +224,29 @@ impl TcpSocket { Ok(result as usize) } } - - /// Return the underlying socket's file descriptor - #[must_use] - pub fn fd(&self) -> i32 { - self.fd - } - - /// Return whether the socket is connected - #[must_use] - pub fn is_connected(&self) -> bool { - self.is_connected - } - - /// Flags used when sending data - #[must_use] - pub fn send_flags(&self) -> SocketSendFlags { - self.send_flags - } - - /// Set the flags used when sending data - pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) { - self.send_flags = send_flags; - } - - /// Flags used when receiving data - #[must_use] - pub fn recv_flags(&self) -> SocketRecvFlags { - self.recv_flags - } - - /// Set the flags used when receiving data - pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) { - self.recv_flags = recv_flags; - } } -impl Drop for TcpSocket { - fn drop(&mut self) { - unsafe { - sys::sceNetInetClose(self.fd); - } - } -} - -impl ErrorType for TcpSocket { +impl ErrorType for TcpSocket { type Error = SocketError; } -impl OptionType for TcpSocket { +impl OptionType for TcpSocket { type Options<'a> = SocketOptions; } -impl<'a> Open<'a> for TcpSocket { +impl<'a> Open<'a> for TcpSocket { + type Return<'b> = TcpSocket; /// Return a TCP socket connected to the remote specified in `options` - fn open(mut self, options: &'a Self::Options<'a>) -> Result + fn open(self, options: &'a Self::Options<'a>) -> Result, Self::Error> where Self: Sized, { - self.connect(options.remote())?; - - Ok(self) + let socket = self.connect(options.remote())?; + Ok(socket) } } -impl Read for TcpSocket { +impl Read for TcpSocket { /// Read from the socket /// /// # Parameters @@ -266,23 +260,17 @@ impl Read for TcpSocket { /// - [`SocketError::NotConnected`] if the socket is not connected /// - A [`SocketError`] if the read was unsuccessful fn read<'m>(&'m mut self, buf: &'m mut [u8]) -> Result { - if !self.is_connected { - return Err(SocketError::NotConnected); - } self._read(buf) } } -impl Write for TcpSocket { +impl Write for TcpSocket { /// Write to the socket /// /// # Errors /// - [`SocketError::NotConnected`] if the socket is not connected /// - A [`SocketError`] if the write was unsuccessful fn write<'m>(&'m mut self, buf: &'m [u8]) -> Result { - if !self.is_connected { - return Err(SocketError::NotConnected); - } self._write(buf) } @@ -295,4 +283,4 @@ impl Write for TcpSocket { } } -impl EasySocket for TcpSocket {} +impl EasySocket for TcpSocket {} diff --git a/src/socket/tls.rs b/src/socket/tls.rs index 52aa198..41e63ec 100644 --- a/src/socket/tls.rs +++ b/src/socket/tls.rs @@ -1,3 +1,5 @@ +#![allow(clippy::module_name_repetitions)] + use alloc::string::String; use embedded_io::{ErrorType, Read, Write}; use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext}; @@ -11,7 +13,7 @@ use crate::{ types::TlsSocketOptions, }; -use super::tcp::TcpSocket; +use super::{state::Connected, tcp::TcpSocket}; lazy_static::lazy_static! { static ref REGEX: Regex = Regex::new("\r|\0").unwrap(); @@ -21,7 +23,7 @@ lazy_static::lazy_static! { /// This is a wrapper around a [`TcpSocket`] that provides a TLS connection. pub struct TlsSocket<'a> { /// The TLS connection - tls_connection: TlsConnection<'a, TcpSocket, Aes128GcmSha256>, + tls_connection: TlsConnection<'a, TcpSocket, Aes128GcmSha256>, /// The TLS config tls_config: TlsConfig<'a, Aes128GcmSha256>, } @@ -51,13 +53,13 @@ impl<'a> TlsSocket<'a> { /// # Notes /// In most cases you can pass `None` for the `cert` parameter. pub fn new( - socket: TcpSocket, + socket: TcpSocket, record_read_buf: &'a mut [u8], record_write_buf: &'a mut [u8], ) -> Self { let tls_config: TlsConfig<'_, Aes128GcmSha256> = TlsConfig::new(); - let tls_connection: TlsConnection = + let tls_connection: TlsConnection, Aes128GcmSha256> = TlsConnection::new(socket, record_read_buf, record_write_buf); TlsSocket { tls_connection, @@ -116,6 +118,7 @@ impl<'a, 'b> Open<'a> for TlsSocket<'b> where 'a: 'b, { + type Return<'c> = Self; /// Open the TLS connection. /// /// # Parameters diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 393466e..4c7c1f1 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -1,4 +1,6 @@ -use alloc::{boxed::Box, vec::Vec}; +#![allow(clippy::module_name_repetitions)] + +use alloc::vec::Vec; use embedded_io::{ErrorType, Read, Write}; use embedded_nal::{IpAddr, Ipv4Addr, SocketAddr}; use psp::sys::{self, sockaddr, socklen_t}; @@ -13,18 +15,13 @@ use crate::{ types::{SocketOptions, SocketRecvFlags, SocketSendFlags}, }; -use super::{super::netc, error::SocketError, ToSockaddr, ToSocketAddr}; - -/// The state of a [`UdpSocket`] -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum UdpSocketState { - /// The socket is not yet bound (the bind method has not been called) - Unbound, - /// The socket is bound (the bind method has been called) - Bound, - /// The socket is connected - Connected, -} +use super::{ + super::netc, + error::SocketError, + sce::SocketFileDescriptor, + state::{Bound, Connected, SocketState, Unbound}, + ToSockaddr, ToSocketAddr, +}; /// A UDP socket /// @@ -37,19 +34,20 @@ pub enum UdpSocketState { /// [`write`](embedded_io::Write::write) and [`read`](embedded_io::Read::read) methods). /// - The socket is closed when the struct is dropped. Closing via drop is best-effort. #[repr(C)] -pub struct UdpSocket { +#[derive(Clone)] +pub struct UdpSocket> { /// The socket file descriptor - fd: i32, + fd: SocketFileDescriptor, /// The remote host to connect to remote: Option, - /// The state of the socket - state: UdpSocketState, /// The buffer to store data to send - buffer: Box, + buffer: B, /// flags for send calls send_flags: SocketSendFlags, /// flags for recv calls recv_flags: SocketRecvFlags, + /// marker for the socket state + _marker: core::marker::PhantomData, } impl UdpSocket { @@ -63,21 +61,76 @@ impl UdpSocket { /// # Errors /// - [`SocketError::Errno`] if the socket could not be created #[allow(dead_code)] - pub fn new() -> Result { + pub fn new() -> Result, SocketError> { let fd = unsafe { sys::sceNetInetSocket(i32::from(netc::AF_INET), netc::SOCK_DGRAM, 0) }; if fd < 0 { Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) } else { + let fd = SocketFileDescriptor::new(fd); Ok(UdpSocket { fd, remote: None, - state: UdpSocketState::Unbound, - buffer: Box::>::default(), + buffer: Vec::with_capacity(0), send_flags: SocketSendFlags::empty(), recv_flags: SocketRecvFlags::empty(), + _marker: core::marker::PhantomData, }) } } +} + +impl UdpSocket { + fn socket_len() -> socklen_t { + core::mem::size_of::() as u32 + } + + /// Get the file descriptor of the socket + #[must_use] + pub fn fd(&self) -> i32 { + *self.fd + } + + /// Get the remote address of the socket + #[must_use] + pub fn remote(&self) -> Option { + self.remote.map(|sockaddr| sockaddr.to_socket_addr()) + } + + /// Flags used when sending data + #[must_use] + pub fn send_flags(&self) -> SocketSendFlags { + self.send_flags + } + + /// Set the flags used when sending data + pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) { + self.send_flags = send_flags; + } + + /// Flags used when receiving data + #[must_use] + pub fn recv_flags(&self) -> SocketRecvFlags { + self.recv_flags + } + + /// Set the flags used when receiving data + pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) { + self.recv_flags = recv_flags; + } +} + +impl UdpSocket { + /// Transition the socket to `Bound` state + fn transition(self, remote: Option) -> UdpSocket { + UdpSocket { + fd: self.fd, + remote, + buffer: Vec::with_capacity(0), + send_flags: self.send_flags, + recv_flags: self.recv_flags, + _marker: core::marker::PhantomData, + } + } /// Bind the socket /// @@ -85,18 +138,13 @@ impl UdpSocket { /// - `addr`: The address to bind to, if `None` binds to `0.0.0.0:0` /// /// # Returns - /// - `Ok(())` if the binding was successful - /// - `Err(String)` if the binding was unsuccessful. + /// - `Ok(UdpSocket)` if the binding was successful + /// - `Err(SocketError)` if the binding was unsuccessful. /// /// # Errors - /// - [`SocketError::AlreadyBound`] if the socket is already bound /// - [`SocketError::Errno`] if the binding was unsuccessful #[allow(unused)] - pub fn bind(&mut self, addr: Option) -> Result<(), SocketError> { - if self.state != UdpSocketState::Unbound { - return Err(SocketError::AlreadyBound); - } - + pub fn bind(mut self, addr: Option) -> Result, SocketError> { let default_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); let addr = addr.unwrap_or(default_addr); match addr { @@ -105,7 +153,7 @@ impl UdpSocket { if unsafe { sys::sceNetInetBind( - self.fd, + *self.fd, &sockaddr, core::mem::size_of::() as u32, ) @@ -114,14 +162,26 @@ impl UdpSocket { let errno = unsafe { sys::sceNetInetGetErrno() }; Err(SocketError::Errno(errno)) } else { - self.remote = Some(sockaddr); - self.state = UdpSocketState::Bound; - Ok(()) + Ok(self.transition(Some(sockaddr))) } } SocketAddr::V6(_) => Err(SocketError::UnsupportedAddressFamily), } } +} + +impl UdpSocket { + /// Transition the socket to `Connected` state + fn transition(self, remote: sockaddr, buf: Option>) -> UdpSocket { + UdpSocket { + fd: self.fd, + remote: Some(remote), + buffer: buf.unwrap_or_default(), + send_flags: self.send_flags, + recv_flags: self.recv_flags, + _marker: core::marker::PhantomData, + } + } /// Connect to a remote host /// @@ -130,89 +190,48 @@ impl UdpSocket { /// To bind the socket use [`bind()`](UdpSocket::bind). /// /// # Returns - /// - `Ok(())` if the connection was successful + /// - `Ok(UdpSocket)` if the connection was successful /// - `Err(SocketError)` if the connection was unsuccessful /// /// # Errors - /// - [`SocketError::NotBound`] if the socket is not bound - /// - [`SocketError::AlreadyConnected`] if the socket is already connected - /// - Any other [`SocketError`] if the connection was unsuccessful + /// - Any [`SocketError`] if the connection was unsuccessful #[allow(unused)] - pub fn connect(&mut self, addr: SocketAddr) -> Result<(), SocketError> { - match self.state { - UdpSocketState::Unbound => return Err(SocketError::NotBound), - UdpSocketState::Bound => {} - UdpSocketState::Connected => return Err(SocketError::AlreadyConnected), - } - + pub fn connect(mut self, addr: SocketAddr) -> Result, SocketError> { match addr { SocketAddr::V4(v4) => { let sockaddr = v4.to_sockaddr(); - if unsafe { sys::sceNetInetConnect(self.fd, &sockaddr, Self::socket_len()) } != 0 { + if unsafe { sys::sceNetInetConnect(*self.fd, &sockaddr, Self::socket_len()) } != 0 { let errno = unsafe { sys::sceNetInetGetErrno() }; Err(SocketError::Errno(errno)) } else { - self.remote = Some(sockaddr); - self.state = UdpSocketState::Connected; - Ok(()) + Ok(self.transition(sockaddr, None)) } } SocketAddr::V6(_) => Err(SocketError::UnsupportedAddressFamily), } } - /// Read from a socket in state [`UdpSocketState::Connected`] - /// - /// # Returns - /// - `Ok(usize)` if the read was successful. The number of bytes read - /// - `Err(SocketError)` if the read was unsuccessful. + /// Read from a bound socket /// - /// # Errors - /// - [`SocketError::NotConnected`] if the socket is not connected - /// - Any other [`SocketError`] if the read was unsuccessful - #[allow(unused)] - pub fn _read(&mut self, buf: &mut [u8]) -> Result { - if self.state != UdpSocketState::Connected { - return Err(SocketError::NotConnected); - } - let mut sockaddr = self.remote.ok_or(SocketError::Other)?; - let result = unsafe { - sys::sceNetInetRecv( - self.fd, - buf.as_mut_ptr().cast::(), - buf.len(), - self.recv_flags.as_i32(), - ) - }; - if result < 0 { - Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) - } else { - Ok(result as usize) - } - } - - /// Write to a socket in state [`UdpSocketState::Bound`] + /// # Parameters + /// - `buf`: The buffer where to store the received data /// /// # Returns - /// - `Ok(usize)` if the write was successful. The number of bytes read + /// - `Ok((usize, UdpSocket))` if the write was successful. The number of bytes read /// - `Err(SocketError)` if the read was unsuccessful. /// /// # Errors - /// - [`SocketError::NotBound`] if the socket is not bound - /// - [`SocketError::AlreadyConnected`] if the socket is already connected - /// - Any other [`SocketError`] if the read was unsuccessful + /// - Any [`SocketError`] if the read was unsuccessful #[allow(unused)] - pub fn _read_from(&mut self, buf: &mut [u8]) -> Result { - match self.state { - UdpSocketState::Unbound => return Err(SocketError::NotBound), - UdpSocketState::Bound => {} - UdpSocketState::Connected => return Err(SocketError::AlreadyConnected), - } + pub fn _read_from( + mut self, + buf: &mut [u8], + ) -> Result<(usize, UdpSocket), SocketError> { let mut sockaddr = self.remote.ok_or(SocketError::Other)?; let result = unsafe { sys::sceNetInetRecvfrom( - self.fd, + *self.fd, buf.as_mut_ptr().cast::(), buf.len(), self.recv_flags.as_i32(), @@ -223,44 +242,41 @@ impl UdpSocket { if result < 0 { Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) } else { - Ok(result as usize) + Ok((result as usize, self.transition(sockaddr, None))) } } - /// Write to a socket in state [`UdpSocketState::Bound`] + /// Write to a bound socket + /// + /// # Parameters + /// - `buf`: The buffer containing the data to send + /// /// - /// /// /// # Returns - /// - `Ok(usize)` if the send was successful. The number of bytes sent + /// - `Ok((usize, UdpSocket))` if the send was successful. The number of bytes sent /// - `Err(SocketError)` if the send was unsuccessful. /// /// # Errors - /// If the socket is not in state [`UdpSocketState::Connected`] this will return an error. - /// It may also error if the socket fails to send the data. + /// - Any [`SocketError`] if the send was unsuccessful #[allow(unused)] pub fn _write_to( - &mut self, + self, buf: &[u8], len: usize, to: SocketAddr, - ) -> Result { - match self.state { - UdpSocketState::Unbound => return Err(SocketError::NotBound), - UdpSocketState::Bound => {} - UdpSocketState::Connected => return Err(SocketError::AlreadyConnected), - } - + ) -> Result<(usize, UdpSocket), SocketError> { let sockaddr = match to { SocketAddr::V4(v4) => Ok(super::socket_addr_v4_to_sockaddr(v4)), SocketAddr::V6(_) => Err(SocketError::UnsupportedAddressFamily), }?; let socklen = core::mem::size_of::() as u32; - self.buffer.append_buffer(buf); + let mut buffer = Vec::with_capacity(buf.len()); + buffer.append_buffer(buf); let result = unsafe { sys::sceNetInetSendto( - self.fd, + *self.fd, buf.as_ptr().cast::(), len, self.send_flags.as_i32(), @@ -271,41 +287,60 @@ impl UdpSocket { if result < 0 { Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) } else { - self.buffer.shift_left_buffer(result as usize); + buffer.shift_left_buffer(result as usize); + Ok((result as usize, self.transition(sockaddr, Some(buffer)))) + } + } +} + +impl UdpSocket { + /// Read from a socket + /// + /// # Parameters + /// - `buf`: The buffer where to store the received data + /// + /// # Returns + /// - `Ok(usize)` if the read was successful. The number of bytes read + /// - `Err(SocketError)` if the read was unsuccessful. + /// + /// # Errors + /// - Any [`SocketError`] if the read was unsuccessful + #[allow(unused)] + pub fn _read(&mut self, buf: &mut [u8]) -> Result { + let result = unsafe { + sys::sceNetInetRecv( + *self.fd, + buf.as_mut_ptr().cast::(), + buf.len(), + self.recv_flags.as_i32(), + ) + }; + if result < 0 { + Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() })) + } else { Ok(result as usize) } } - /// Write to a socket in state [`UdpSocketState::Connected`] + /// Write to a socket /// /// # Returns /// - `Ok(usize)` if the send was successful. The number of bytes sent /// - `Err(SocketError)` if the send was unsuccessful. /// /// # Errors - /// If the socket is not in state [`UdpSocketState::Connected`] this will return an error. - /// It may also error if the socket fails to send the data. + /// - Any [`SocketError`] if the send was unsuccessful #[allow(unused)] pub fn _write(&mut self, buf: &[u8]) -> Result { - if self.state != UdpSocketState::Connected { - return Err(SocketError::NotConnected); - } - self.buffer.append_buffer(buf); - self.send() } /// Flush the send buffer /// /// # Errors - /// - [`SocketError::NotConnected`] if the socket is not connected - /// - Any other [`SocketError`] if the flush was unsuccessful + /// - Any [`SocketError`] if the flush was unsuccessful. pub fn _flush(&mut self) -> Result<(), SocketError> { - if self.state != UdpSocketState::Connected { - return Err(SocketError::NotConnected); - } - while !self.buffer.is_empty() { self.send()?; } @@ -315,7 +350,7 @@ impl UdpSocket { fn send(&mut self) -> Result { let result = unsafe { sys::sceNetInetSend( - self.fd, + *self.fd, self.buffer.as_slice().as_ptr().cast::(), self.buffer.len(), self.send_flags.as_i32(), @@ -328,139 +363,73 @@ impl UdpSocket { Ok(result as usize) } } - - /// Get the state of the socket - /// - /// # Returns - /// The state of the socket, one of [`UdpSocketState`] - #[must_use] - pub fn get_state(&self) -> UdpSocketState { - self.state - } - - fn socket_len() -> socklen_t { - core::mem::size_of::() as u32 - } - - /// Get the file descriptor of the socket - #[must_use] - pub fn fd(&self) -> i32 { - self.fd - } - - /// Get the remote address of the socket - #[must_use] - pub fn remote(&self) -> Option { - self.remote.map(|sockaddr| sockaddr.to_socket_addr()) - } - - /// Get the state of the socket - #[must_use] - pub fn state(&self) -> UdpSocketState { - self.state - } - - /// Flags used when sending data - #[must_use] - pub fn send_flags(&self) -> SocketSendFlags { - self.send_flags - } - - /// Set the flags used when sending data - pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) { - self.send_flags = send_flags; - } - - /// Flags used when receiving data - #[must_use] - pub fn recv_flags(&self) -> SocketRecvFlags { - self.recv_flags - } - - /// Set the flags used when receiving data - pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) { - self.recv_flags = recv_flags; - } -} - -impl Drop for UdpSocket { - /// Close the socket - fn drop(&mut self) { - unsafe { - sys::sceNetInetClose(self.fd); - } - } } -impl OptionType for UdpSocket { +impl OptionType for UdpSocket { type Options<'a> = SocketOptions; } -impl ErrorType for UdpSocket { +impl ErrorType for UdpSocket { type Error = SocketError; } -impl<'a> Open<'a> for UdpSocket { +impl<'a> Open<'a> for UdpSocket { + type Return<'b> = UdpSocket; /// Open the socket /// /// # Parameters - /// - `options`: The options to use when opening the socket + /// - `options`: The options to use when opening the socket. /// /// # Returns - /// - `Ok(Self)` if the socket was opened successfully - /// - `Err(SocketError)` if the socket failed to open + /// - `Ok(UdpSocket)` if the socket was opened successfully + /// - `Err(SocketError)` if the socket failed to open. /// /// # Examples /// ```no_run - /// let mut socket = UdpSocket::new()?; - /// socket.open(&SocketOptions::default())?; + /// let socket = UdpSocket::new()?; + /// let socket = socket.open(&SocketOptions::default())?; /// ``` - fn open(mut self, options: &'a Self::Options<'a>) -> Result { - self.bind(None)?; - self.connect(options.remote())?; - - Ok(self) + fn open(self, options: &'a Self::Options<'a>) -> Result, Self::Error> { + let sock = self.bind(None)?; + let sock = sock.connect(options.remote())?; + Ok(sock) } } -impl Read for UdpSocket { +impl Read for UdpSocket { /// Read from the socket /// - /// # Notes - /// If the socket is in state [`UdpSocketState::Unbound`] this will return an error, - /// otherwise it will attempt to read from the socket. You can check the state of the socket - /// using [`get_state`](Self::get_state). + /// # Parameters + /// - `buf`: The buffer where the read data will be stored + /// + /// # Returns + /// - `Ok(usize)` if the read was successful. The number of bytes read + /// - `Err(SocketError)` if the read was unsuccessful. fn read(&mut self, buf: &mut [u8]) -> Result { - match self.get_state() { - UdpSocketState::Unbound => Err(SocketError::NotBound), - UdpSocketState::Bound => self._read_from(buf), - UdpSocketState::Connected => self._read(buf), - } + self._read(buf) } } -impl Write for UdpSocket { +impl Write for UdpSocket { /// Write to the socket /// - /// # Notes - /// If the socket is not in state [`UdpSocketState::Connected`] this will return an error. - /// To connect to a remote host use [`connect`](UdpSocket::connect) first. + /// # Parameters + /// - `buf`: The data to write + /// + /// # Returns + /// - `Ok(usize)` if the write was successful. The number of bytes written + /// - `Err(SocketError)` if the write was unsuccessful. fn write(&mut self, buf: &[u8]) -> Result { - match self.get_state() { - UdpSocketState::Unbound => Err(SocketError::NotBound), - UdpSocketState::Bound => Err(SocketError::NotConnected), - UdpSocketState::Connected => self._write(buf), - } + self._write(buf) } /// Flush the socket + /// + /// # Errors + /// - Any [`SocketError`] if the flush was unsuccessful. fn flush(&mut self) -> Result<(), Self::Error> { - match self.get_state() { - UdpSocketState::Unbound => Err(SocketError::NotBound), - UdpSocketState::Bound => Err(SocketError::NotConnected), - UdpSocketState::Connected => self._flush(), - } + self._flush() } } -impl EasySocket for UdpSocket {} +impl EasySocket for UdpSocket {} diff --git a/src/traits/dns.rs b/src/traits/dns.rs index 656bc45..38a1df6 100644 --- a/src/traits/dns.rs +++ b/src/traits/dns.rs @@ -1,3 +1,5 @@ +#![allow(clippy::module_name_repetitions)] + use core::fmt::Debug; use alloc::string::String; diff --git a/src/traits/io.rs b/src/traits/io.rs index 9edcd9d..88d6564 100644 --- a/src/traits/io.rs +++ b/src/traits/io.rs @@ -4,6 +4,8 @@ pub trait OptionType { /// Type implementing this trait support a Open semantics. pub trait Open<'a>: ErrorType + OptionType { + type Return<'b>; + /// Open a resource, using options for configuration. /// /// # Arguments @@ -15,7 +17,7 @@ pub trait Open<'a>: ErrorType + OptionType { /// # Notes /// See [`TlsSocketOptions`](crate::types::TlsSocketOptions) for more information /// on the options you can pass. - fn open(self, options: &'a Self::Options<'a>) -> Result + fn open(self, options: &'a Self::Options<'a>) -> Result, Self::Error> where Self: Sized; } @@ -34,7 +36,7 @@ pub trait Open<'a>: ErrorType + OptionType { /// # Notes /// [`EasySocket`] types should implement in their [`drop`] method the steps required /// to close the acquired resources. -pub trait EasySocket: for<'a> Open<'a> + Write + Read {} +pub trait EasySocket: Write + Read {} // re-exports pub trait Write = embedded_io::Write; diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 77f5b3f..4b1b815 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -1,10 +1,11 @@ use alloc::vec::Vec; +use core::fmt::Debug; pub mod dns; pub mod io; /// A trait for a buffer that can be used with a socket -pub trait SocketBuffer { +pub trait SocketBuffer: Clone + Debug + Default { /// Create a new buffer fn new() -> Self where diff --git a/src/types/socket_flags.rs b/src/types/socket_flags.rs index 7d5a90b..00c109d 100644 --- a/src/types/socket_flags.rs +++ b/src/types/socket_flags.rs @@ -2,7 +2,7 @@ use bitflags::bitflags; bitflags! { /// Socket flags to use in send calls - #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)] pub struct SocketSendFlags: u32 { /// No flags passed. Equivalent to `0x0`. const NONE = 0x0; @@ -23,7 +23,7 @@ impl SocketSendFlags { bitflags! { /// Socket flags to use in recv calls - #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)] pub struct SocketRecvFlags: u32 { /// No flags passed. Equivalent to `0x0`. const NONE = 0x0; diff --git a/src/types/socket_options.rs b/src/types/socket_options.rs index ea0152f..19e1e1d 100644 --- a/src/types/socket_options.rs +++ b/src/types/socket_options.rs @@ -1,3 +1,5 @@ +#![allow(clippy::module_name_repetitions)] + use alloc::string::String; use crate::socket::SocketAddr;