Skip to content

Commit

Permalink
Make Message struct generic
Browse files Browse the repository at this point in the history
It is no longer required to be Cow, and for format.rs it is constructed
with PyBackedBytes instead.
  • Loading branch information
bmerry committed Mar 26, 2024
1 parent 33f19ab commit 5981c9f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
12 changes: 9 additions & 3 deletions src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ use std::io::Write;

use crate::message::{Message, MessageType};

impl Message<'_> {
impl<N, A> Message<N, A>
where
N: AsRef<[u8]>,
A: AsRef<[u8]>,
{
/// Serialize the message
pub fn write<T: Write>(&self, target: &mut T) -> std::io::Result<()> {
let type_symbol = match self.mtype {
Expand All @@ -26,7 +30,7 @@ impl Message<'_> {
MessageType::Inform => b'#',
};
target.write_all(std::slice::from_ref(&type_symbol))?;
target.write_all(&self.name)?;
target.write_all(self.name.as_ref())?;
if let Some(mid) = self.mid {
target.write_all(b"[")?;
let mut buffer = itoa::Buffer::new();
Expand All @@ -35,6 +39,7 @@ impl Message<'_> {
target.write_all(b"]")?;
}
for argument in self.arguments.iter() {
let argument = argument.as_ref();
target.write_all(b" ")?;
if argument.is_empty() {
target.write_all(b"\\@")?;
Expand Down Expand Up @@ -67,13 +72,14 @@ impl Message<'_> {
/// Get the number of bytes needed by [write].
pub fn write_size(&self) -> usize {
// Type symbol, name, spaces and newline
let mut bytes = 2 + self.name.len() + self.arguments.len();
let mut bytes = 2 + self.name.as_ref().len() + self.arguments.len();
if let Some(mid) = self.mid {
let mut buffer = itoa::Buffer::new();
let mid_formatted = buffer.format(mid);
bytes += 2 + mid_formatted.len();
}
for argument in self.arguments.iter() {
let argument = argument.as_ref();
if argument.is_empty() {
bytes += 2; // For the \@
}
Expand Down
39 changes: 24 additions & 15 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::{PyBytes, PyList};
use pyo3::PyTraverseError;

use std::borrow::Cow;
use std::ops::Deref;

/// Type of katcp message
#[pyclass(module = "katcp_codec._lib", rename_all = "SCREAMING_SNAKE_CASE")]
Expand All @@ -42,26 +40,34 @@ pub enum MessageType {
/// it can be decoded as ASCII (or UTF-8) but the arguments may contain
/// arbitrary bytes.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct Message<'data> {
pub struct Message<N, A>
where
N: AsRef<[u8]>,
A: AsRef<[u8]>,
{
/// Message type
pub mtype: MessageType,
/// Message name
pub name: Cow<'data, [u8]>,
pub name: N,
/// Message ID, if present. It must be positive.
pub mid: Option<i32>,
/// Message arguments
pub arguments: Vec<Cow<'data, [u8]>>,
pub arguments: Vec<A>,
}

impl<'data> Message<'data> {
impl<N, A> Message<N, A>
where
N: AsRef<[u8]>,
A: AsRef<[u8]>,
{
/// Create a new message.
///
/// Panics if the message ID is given and is not positive.
pub fn new(
mtype: MessageType,
name: impl Into<Cow<'data, [u8]>>,
name: impl Into<N>,
mid: Option<i32>,
arguments: impl Into<Vec<Cow<'data, [u8]>>>,
arguments: impl Into<Vec<A>>,
) -> Self {
assert!(mid.map_or(true, |x| x > 0));
Self {
Expand Down Expand Up @@ -154,12 +160,11 @@ impl PyMessage {
.as_ref()
.ok_or_else(|| PyValueError::new_err("name is None"))?;
// TODO: this is creating a new vector to hold the arguments.
// Can we use a trait to handle directly iterating the PyList?
let py_arguments: Vec<PyBackedBytes> = arguments.bind(py).extract()?;
let arguments: Vec<_> = py_arguments.iter().map(|x| Cow::from(x.deref())).collect();
// Can we use another trait to handle directly iterating the PyList?
let arguments: Vec<PyBackedBytes> = arguments.bind(py).extract()?;
let message = Message {
mtype: self.mtype,
name: Cow::from(name.as_bytes()),
name: name.as_bytes(),
mid: self.mid,
arguments,
};
Expand All @@ -170,13 +175,17 @@ impl PyMessage {
}
}

impl<'data> ToPyObject for Message<'data> {
impl<N, A> ToPyObject for Message<N, A>
where
N: AsRef<[u8]>,
A: AsRef<[u8]>,
{
fn to_object(&self, py: Python<'_>) -> PyObject {
let py_msg = PyMessage::new(
self.mtype,
PyBytes::new_bound(py, &self.name).unbind(),
PyBytes::new_bound(py, self.name.as_ref()).unbind(),
self.mid,
PyList::new_bound(py, self.arguments.iter()).unbind(),
PyList::new_bound(py, self.arguments.iter().map(|x| Cow::from(x.as_ref()))).unbind(),
);
py_msg.into_py(py)
}
Expand Down
11 changes: 8 additions & 3 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use thiserror::Error;

use crate::message::{Message, MessageType};

type ParsedMessage<'data> = Message<Cow<'data, [u8]>, Cow<'data, [u8]>>;

/// State in the state machine
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash, Enum)]
enum State {
Expand Down Expand Up @@ -152,7 +154,7 @@ impl<'parser, 'data> Iterator for ParseIterator<'parser, 'data>
where
'parser: 'data,
{
type Item = Result<Message<'data>, ParseError>;
type Item = Result<ParsedMessage<'data>, ParseError>;

fn next(&mut self) -> Option<Self::Item> {
let (msg, tail) = self.parser.next_message(self.data, &mut self.transient);
Expand Down Expand Up @@ -423,7 +425,7 @@ impl Parser {
chunk: &'data [u8],
transient: &mut Transient<'data>,
position: usize,
) -> Result<Option<Message<'data>>, ParseError> {
) -> Result<Option<ParsedMessage<'data>>, ParseError> {
match action {
Action::SetType(mtype) => {
self.mtype = Some(*mtype);
Expand Down Expand Up @@ -490,7 +492,10 @@ impl Parser {
&mut self,
mut data: &'data [u8],
transient: &mut Transient<'data>,
) -> (Option<Result<Message<'data>, ParseError>>, &'data [u8]) {
) -> (
Option<Result<ParsedMessage<'data>, ParseError>>,
&'data [u8],
) {
while !data.is_empty() {
if self.line_length >= self.max_line_length && self.state != State::Error {
self.error(transient, "Line too long");
Expand Down

0 comments on commit 5981c9f

Please sign in to comment.