Skip to content

Commit

Permalink
feat: safe init (#7)
Browse files Browse the repository at this point in the history
* feat: make init functions safe and leveraging `Result`

* docs: improve docs

* docs: update README
  • Loading branch information
lorenzofelletti authored May 27, 2024
1 parent a98cb81 commit 3af2611
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 30 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ It provides many useful features, notably:
- A TLS socket
- A DNS resolver.

The TCP Socket provided by this crate is compatible with [embedded-tls](https://github.com/drogue-iot/embedded-tls) TLS socket library.

# Notes
This crate require the use of nightly Rust.
25 changes: 25 additions & 0 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ pub struct DnsResolver {

impl DnsResolver {
/// Create a new DNS resolver
///
/// # Parameters
/// - `dns`: The [`SocketAddr`] of the DNS server
///
/// # Errors
/// - [`DnsError::FailedToCreate`]: The DNS resolver failed to create. This may
/// happen if the socket could not be created or bound to the specified address
#[allow(unused)]
pub fn new(dns: SocketAddr) -> Result<Self, DnsError> {
let mut udp_socket = UdpSocket::new().map_err(|_| DnsError::FailedToCreate)?;
Expand Down Expand Up @@ -75,6 +82,11 @@ impl DnsResolver {
/// # Returns
/// - `Ok(in_addr)`: The IP address of the hostname
/// - `Err(())`: If the hostname could not be resolved
///
/// # Errors
/// - [`DnsError::HostnameResolutionFailed`]: The hostname could not be resolved.
/// This may happen if the connection of the socket fails, or if the DNS server
/// does not answer the query, or any other error occurs
pub fn resolve(&mut self, host: &str) -> Result<in_addr, DnsError> {
// connect to the DNS server, if not already
if self.udp_socket.get_state() != UdpSocketState::Connected {
Expand Down Expand Up @@ -158,6 +170,19 @@ impl DnsResolver {
impl traits::dns::ResolveHostname for DnsResolver {
type Error = DnsError;

/// Resolve a hostname to an IP address
///
/// # Parameters
/// - `host`: The hostname to resolve
///
/// # Returns
/// - `Ok(SocketAddr)`: The IP address of the hostname
/// - `Err(DnsError)`: If the hostname could not be resolved
///
/// # Errors
/// - [`DnsError::HostnameResolutionFailed`]: The hostname could not be resolved.
/// This may happen if the connection of the socket fails, or if the DNS server
/// does not answer the query, or any other error occurs
fn resolve_hostname(&mut self, hostname: &str) -> Result<SocketAddr, DnsError> {
self.resolve(hostname).map(|addr| addr.to_socket_addr())
}
Expand Down
2 changes: 1 addition & 1 deletion src/socket/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub enum SocketError {

impl Display for SocketError {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{:?}", self)
write!(f, "{self:?}")
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ impl TcpSocket {
}
}

#[allow(dead_code)]
/// Connect to a remote host
///
/// # Parameters
/// - `remote`: The remote host to connect to
///
/// # Returns
/// - `Ok(())` if the connection was successful
#[allow(dead_code)]
#[allow(clippy::cast_possible_truncation)]
/// - `Err(String)` if the connection was unsuccessful.
pub fn connect(&mut self, remote: SocketAddr) -> Result<(), SocketError> {
if self.is_connected {
Expand Down
15 changes: 8 additions & 7 deletions src/socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use crate::{

use super::{super::netc, error::SocketError, ToSockaddr};

#[derive(Clone, Copy, PartialEq, Eq)]
/// The state of a [`UdpSocket`]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum UdpSocketState {
/// The socket is not yet bound (the bind method has not been called)
Unbound,
Expand Down Expand Up @@ -50,8 +50,8 @@ pub struct UdpSocket {
}

impl UdpSocket {
#[allow(dead_code)]
/// Create a socket
#[allow(dead_code)]
pub fn new() -> Result<UdpSocket, SocketError> {
let fd = unsafe { sys::sceNetInetSocket(netc::AF_INET as i32, netc::SOCK_DGRAM, 0) };
if fd < 0 {
Expand All @@ -66,7 +66,6 @@ impl UdpSocket {
}
}

#[allow(unused)]
/// Bind the socket
///
/// # Parameters
Expand All @@ -75,6 +74,7 @@ impl UdpSocket {
/// # Returns
/// - `Ok(())` if the binding was successful
/// - `Err(String)` if the binding was unsuccessful.
#[allow(unused)]
pub fn bind(&mut self, addr: Option<SocketAddr>) -> Result<(), SocketError> {
if self.state != UdpSocketState::Unbound {
return Err(SocketError::AlreadyBound);
Expand Down Expand Up @@ -106,11 +106,11 @@ impl UdpSocket {
}
}

#[allow(unused)]
/// Connect to a remote host
///
/// # Notes
/// The socket must be in state [`UdpSocketState::Bound`] to connect to a remote host.
#[allow(unused)]
pub fn connect(&mut self, addr: SocketAddr) -> Result<(), SocketError> {
match self.state {
UdpSocketState::Unbound => return Err(SocketError::NotBound),
Expand All @@ -135,8 +135,8 @@ impl UdpSocket {
}
}

#[allow(unused)]
/// Read from a socket in state [`UdpSocketState::Connected`]
#[allow(unused)]
fn _read(&mut self, buf: &mut [u8]) -> Result<usize, SocketError> {
if self.state != UdpSocketState::Connected {
return Err(SocketError::NotConnected);
Expand All @@ -151,8 +151,8 @@ impl UdpSocket {
}
}

#[allow(unused)]
/// Write to a socket in state [`UdpSocketState::Bound`]
#[allow(unused)]
fn _read_from(&mut self, buf: &mut [u8]) -> Result<usize, SocketError> {
match self.state {
UdpSocketState::Unbound => return Err(SocketError::NotBound),
Expand All @@ -177,8 +177,9 @@ impl UdpSocket {
}
}

#[allow(unused)]
/// Write to a socket in state [`UdpSocketState::Bound`]
#[allow(unused)]
#[allow(clippy::cast_possible_truncation)]
fn _write_to(&mut self, buf: &[u8], len: usize, to: SocketAddr) -> Result<usize, SocketError> {
match self.state {
UdpSocketState::Unbound => return Err(SocketError::NotBound),
Expand Down
2 changes: 1 addition & 1 deletion src/traits/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ pub trait Open: ErrorType + OptionType {
/// already close the resources.
///
/// # Notes
/// EasyScoket types should implement in their [`drop`] method the steps required
/// [`EasySocket`] types should implement in their [`drop`] method the steps required
/// to close the acquired resources.
pub trait EasySocket: Open + Write + Read {}
134 changes: 114 additions & 20 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,125 @@
use alloc::{borrow::ToOwned, string::String};

/// Error type for net functions
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum NetError {
/// Failed to load a net module
LoadModuleFailed(String, i32),
/// Failed to initialize
InitFailed(String, i32),
/// An error occurred when using a net function
Error(String, i32),
}

impl NetError {
pub fn load_module_failed(module: &str, error: i32) -> Self {
NetError::LoadModuleFailed(module.to_owned(), error)
}

pub fn init_failed(fn_name: &str, error: i32) -> Self {
NetError::InitFailed(fn_name.to_owned(), error)
}

pub fn error(fn_name: &str, error: i32) -> Self {
NetError::Error(fn_name.to_owned(), error)
}
}

/// Load net modules
///
/// # Safety
/// This function is unsafe because it loads `rust-psp`'s net modules, which are unsafe.
#[allow(dead_code)]
/// # Errors
/// - [`NetError::LoadModuleFailed`] if the net module could not be loaded
#[allow(unused)]
#[inline]
pub unsafe fn load_net_modules() {
psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetCommon);
psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetInet);
psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetParseUri);
psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetHttp);
pub fn load_net_modules() -> Result<(), NetError> {
unsafe {
let res = psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetCommon);
if res != 0 {
return Err(NetError::load_module_failed("", res));
}

let res = psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetInet);
if res != 0 {
return Err(NetError::load_module_failed("", res));
}

let res = psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetParseUri);
if res != 0 {
return Err(NetError::load_module_failed("", res));
}

let res = psp::sys::sceUtilityLoadNetModule(psp::sys::NetModule::NetHttp);
if res != 0 {
return Err(NetError::load_module_failed("", res));
}

Ok(())
}
}

/// Initialize net modules
/// Initialize network
///
/// # Safety
/// This function is unsafe because it loads `rust-psp`'s net modules, which are unsafe.
#[allow(dead_code)]
/// # Errors
/// - [`NetError::InitFailed`] if the net could not be initialized
#[allow(unused)]
#[inline]
pub unsafe fn net_init() {
psp::sys::sceNetInit(0x20000, 0x20, 0x1000, 0x20, 0x1000);
psp::sys::sceNetInetInit();
psp::sys::sceNetResolverInit();
psp::sys::sceNetApctlInit(0x1600, 42);
pub fn net_init() -> Result<(), NetError> {
unsafe {
let res = psp::sys::sceNetInit(0x20000, 0x20, 0x1000, 0x20, 0x1000);
if res != 0 {
return Err(NetError::init_failed("sceNetInit", res));
}

let res = psp::sys::sceNetInetInit();
if res != 0 {
return Err(NetError::init_failed("sceNetInetInit", res));
}

let res = psp::sys::sceNetResolverInit();
if res != 0 {
return Err(NetError::init_failed("sceNetResolverInit", res));
}

let res = psp::sys::sceNetApctlInit(0x1600, 42);
if res != 0 {
return Err(NetError::init_failed("sceNetApctlInit", res));
}
}

Ok(())
}

#[allow(dead_code)]
/// Select net config
///
/// # Errors
/// This function will return an error if selection fails.
/// The error, if any, will always be [`NetError::Error`].
///
/// # Notes
/// The netconfigs start from 1.
///
/// Remember that this function requires the [net modules](crate::utils::load_net_modules) to be loaded, and
/// [initialised](crate::utils::net_init) first.
#[allow(unused)]
#[inline]
pub fn select_netconfig(id: i32) -> Result<(), NetError> {
unsafe {
let res = psp::sys::sceUtilityCheckNetParam(id);
if res != 0 {
return Err(NetError::error("sceUtilityCheckNetParam", res));
}
}

Ok(())
}

/// Select first net config
///
/// # Errors
/// This function will return an error if selection fails
#[allow(unused)]
#[inline]
pub fn select_netconfig() -> i32 {
unsafe { psp::sys::sceUtilityCheckNetParam(1) }
pub fn select_first_netconfig() -> Result<(), NetError> {
select_netconfig(1)
}

0 comments on commit 3af2611

Please sign in to comment.