From 5981c9f94cc182483dd51eaff1d354ada8072ebd Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Tue, 26 Mar 2024 09:51:10 +0200 Subject: [PATCH] Make Message struct generic It is no longer required to be Cow, and for format.rs it is constructed with PyBackedBytes instead. --- src/format.rs | 12 +++++++++--- src/message.rs | 39 ++++++++++++++++++++++++--------------- src/parse.rs | 11 ++++++++--- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/format.rs b/src/format.rs index b8ce05b..5413148 100644 --- a/src/format.rs +++ b/src/format.rs @@ -17,7 +17,11 @@ use std::io::Write; use crate::message::{Message, MessageType}; -impl Message<'_> { +impl Message +where + N: AsRef<[u8]>, + A: AsRef<[u8]>, +{ /// Serialize the message pub fn write(&self, target: &mut T) -> std::io::Result<()> { let type_symbol = match self.mtype { @@ -26,7 +30,7 @@ impl Message<'_> { MessageType::Inform => b'#', }; target.write_all(std::slice::from_ref(&type_symbol))?; - target.write_all(&self.name)?; + target.write_all(self.name.as_ref())?; if let Some(mid) = self.mid { target.write_all(b"[")?; let mut buffer = itoa::Buffer::new(); @@ -35,6 +39,7 @@ impl Message<'_> { target.write_all(b"]")?; } for argument in self.arguments.iter() { + let argument = argument.as_ref(); target.write_all(b" ")?; if argument.is_empty() { target.write_all(b"\\@")?; @@ -67,13 +72,14 @@ impl Message<'_> { /// Get the number of bytes needed by [write]. pub fn write_size(&self) -> usize { // Type symbol, name, spaces and newline - let mut bytes = 2 + self.name.len() + self.arguments.len(); + let mut bytes = 2 + self.name.as_ref().len() + self.arguments.len(); if let Some(mid) = self.mid { let mut buffer = itoa::Buffer::new(); let mid_formatted = buffer.format(mid); bytes += 2 + mid_formatted.len(); } for argument in self.arguments.iter() { + let argument = argument.as_ref(); if argument.is_empty() { bytes += 2; // For the \@ } diff --git a/src/message.rs b/src/message.rs index dbf7dbc..d92aa5f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -21,9 +21,7 @@ use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::{PyBytes, PyList}; use pyo3::PyTraverseError; - use std::borrow::Cow; -use std::ops::Deref; /// Type of katcp message #[pyclass(module = "katcp_codec._lib", rename_all = "SCREAMING_SNAKE_CASE")] @@ -42,26 +40,34 @@ pub enum MessageType { /// it can be decoded as ASCII (or UTF-8) but the arguments may contain /// arbitrary bytes. #[derive(Clone, Eq, PartialEq, Debug)] -pub struct Message<'data> { +pub struct Message +where + N: AsRef<[u8]>, + A: AsRef<[u8]>, +{ /// Message type pub mtype: MessageType, /// Message name - pub name: Cow<'data, [u8]>, + pub name: N, /// Message ID, if present. It must be positive. pub mid: Option, /// Message arguments - pub arguments: Vec>, + pub arguments: Vec, } -impl<'data> Message<'data> { +impl Message +where + N: AsRef<[u8]>, + A: AsRef<[u8]>, +{ /// Create a new message. /// /// Panics if the message ID is given and is not positive. pub fn new( mtype: MessageType, - name: impl Into>, + name: impl Into, mid: Option, - arguments: impl Into>>, + arguments: impl Into>, ) -> Self { assert!(mid.map_or(true, |x| x > 0)); Self { @@ -154,12 +160,11 @@ impl PyMessage { .as_ref() .ok_or_else(|| PyValueError::new_err("name is None"))?; // TODO: this is creating a new vector to hold the arguments. - // Can we use a trait to handle directly iterating the PyList? - let py_arguments: Vec = arguments.bind(py).extract()?; - let arguments: Vec<_> = py_arguments.iter().map(|x| Cow::from(x.deref())).collect(); + // Can we use another trait to handle directly iterating the PyList? + let arguments: Vec = arguments.bind(py).extract()?; let message = Message { mtype: self.mtype, - name: Cow::from(name.as_bytes()), + name: name.as_bytes(), mid: self.mid, arguments, }; @@ -170,13 +175,17 @@ impl PyMessage { } } -impl<'data> ToPyObject for Message<'data> { +impl ToPyObject for Message +where + N: AsRef<[u8]>, + A: AsRef<[u8]>, +{ fn to_object(&self, py: Python<'_>) -> PyObject { let py_msg = PyMessage::new( self.mtype, - PyBytes::new_bound(py, &self.name).unbind(), + PyBytes::new_bound(py, self.name.as_ref()).unbind(), self.mid, - PyList::new_bound(py, self.arguments.iter()).unbind(), + PyList::new_bound(py, self.arguments.iter().map(|x| Cow::from(x.as_ref()))).unbind(), ); py_msg.into_py(py) } diff --git a/src/parse.rs b/src/parse.rs index a5180d6..93c4609 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -25,6 +25,8 @@ use thiserror::Error; use crate::message::{Message, MessageType}; +type ParsedMessage<'data> = Message, Cow<'data, [u8]>>; + /// State in the state machine #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash, Enum)] enum State { @@ -152,7 +154,7 @@ impl<'parser, 'data> Iterator for ParseIterator<'parser, 'data> where 'parser: 'data, { - type Item = Result, ParseError>; + type Item = Result, ParseError>; fn next(&mut self) -> Option { let (msg, tail) = self.parser.next_message(self.data, &mut self.transient); @@ -423,7 +425,7 @@ impl Parser { chunk: &'data [u8], transient: &mut Transient<'data>, position: usize, - ) -> Result>, ParseError> { + ) -> Result>, ParseError> { match action { Action::SetType(mtype) => { self.mtype = Some(*mtype); @@ -490,7 +492,10 @@ impl Parser { &mut self, mut data: &'data [u8], transient: &mut Transient<'data>, - ) -> (Option, ParseError>>, &'data [u8]) { + ) -> ( + Option, ParseError>>, + &'data [u8], + ) { while !data.is_empty() { if self.line_length >= self.max_line_length && self.state != State::Error { self.error(transient, "Line too long");