diff --git a/src/modbus.rs b/src/modbus.rs index 8e3663f..50cea86 100644 --- a/src/modbus.rs +++ b/src/modbus.rs @@ -101,6 +101,76 @@ pub fn guess_response_size(function: u8, quantity: u16) -> usize { } } +/// Extracts a 16-bit unsigned integer from a Modbus RTU request frame starting at the specified index. +/// +/// This function attempts to retrieve two consecutive bytes from the provided request slice, +/// starting at the given index, and converts them into a `u16` value using big-endian byte order. +/// If the request slice is too short to contain the required bytes, it returns a `RelayError` +/// indicating an invalid frame format. +/// +/// # Arguments +/// +/// * `request` - A slice of bytes representing the Modbus RTU request frame. +/// * `start` - The starting index within the request slice from which to extract the `u16` value. +/// +/// # Returns +/// +/// A `Result` containing the extracted `u16` value if successful, or a `RelayError` if the request +/// slice is too short. +/// +/// # Errors +/// +/// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain +/// enough bytes to extract a `u16` value starting at the specified index. +fn get_u16_from_request(request: &[u8], start: usize) -> Result { + request + .get(start..start + 2) + .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) + .ok_or_else(|| { + RelayError::frame( + FrameErrorKind::InvalidFormat, + "Request too short for register quantity".to_string(), + Some(request.to_vec()), + ) + }) +} + +/// Extracts the quantity of coils or registers from a Modbus RTU request frame based on the function code. +/// +/// This function determines the quantity of coils or registers involved in a Modbus RTU request +/// by examining the function code and extracting the appropriate bytes from the request frame. +/// For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10), it extracts a 16-bit +/// unsigned integer from bytes 4 and 5 of the request frame. For write single functions (0x05, 0x06), +/// it returns a fixed quantity of 1. For other function codes, it defaults to a quantity of 1. +/// +/// # Arguments +/// +/// * `function_code` - The Modbus function code. +/// * `request` - A slice of bytes representing the Modbus RTU request frame. +/// +/// # Returns +/// +/// A `Result` containing the extracted quantity as a `u16` value if successful, or a `RelayError` if the request +/// slice is too short or the function code is invalid. +/// +/// # Errors +/// +/// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain +/// enough bytes to extract the quantity for the specified function code. +pub fn get_quantity(function_code: u8, request: &[u8]) -> Result { + match function_code { + // For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10), + // extract the quantity from bytes 4 and 5 of the request frame. + 0x01..=0x04 | 0x0F | 0x10 => get_u16_from_request(request, 4), + + // For write single functions (0x05, 0x06), the quantity is always 1. + 0x05 | 0x06 => Ok(1), + + // For other function codes, default the quantity to 1. + _ => Ok(1), + } +} + pub struct ModbusProcessor { transport: Arc, } @@ -147,11 +217,8 @@ impl ModbusProcessor { // Estimate the expected RTU response size let function_code = pdu.first().copied().unwrap_or(0); - let quantity = if pdu.len() >= 4 { - u16::from_be_bytes([pdu[2], pdu[3]]) - } else { - 0 - }; + let quantity = get_quantity(function_code, &rtu_request)?; + let expected_response_size = guess_response_size(function_code, quantity); // Allocate buffer for RTU response diff --git a/src/rtu_transport.rs b/src/rtu_transport.rs index 035c030..018a69e 100644 --- a/src/rtu_transport.rs +++ b/src/rtu_transport.rs @@ -15,9 +15,7 @@ use tracing::{info, trace}; use crate::{RtsError, RtsType}; -use crate::{ - guess_response_size, FrameErrorKind, IoOperation, RelayError, RtuConfig, TransportError, -}; +use crate::{FrameErrorKind, IoOperation, RelayError, RtuConfig, TransportError}; pub struct RtuTransport { port: Mutex>, @@ -99,7 +97,7 @@ impl RtuTransport { Ok(()) } - fn set_rts(&self, on: bool) -> Result<(), TransportError> { + fn set_rts(&self, on: bool, trace_frames: bool) -> Result<(), TransportError> { let rts_span = tracing::info_span!( "rts_control", signal = if on { "HIGH" } else { "LOW" }, @@ -137,7 +135,9 @@ impl RtuTransport { )))); } - info!("RTS set to {}", if on { "HIGH" } else { "LOW" }); + if trace_frames { + trace!("RTS set to {}", if on { "HIGH" } else { "LOW" }); + } } Ok(()) @@ -173,59 +173,40 @@ impl RtuTransport { )); } + let expected_size = response.len(); + if self.trace_frames { trace!("TX: {} bytes: {:02X?}", request.len(), request); + trace!("Expected response size: {} bytes", expected_size); } - let function = request.get(1).ok_or_else(|| { - RelayError::frame( - FrameErrorKind::InvalidFormat, - "Request too short to contain function code".to_string(), - Some(request.to_vec()), - ) - })?; - - let quantity = if *function == 0x03 || *function == 0x04 { - u16::from_be_bytes([ - *request.get(4).ok_or_else(|| { - RelayError::frame( - FrameErrorKind::InvalidFormat, - "Request too short for register quantity".to_string(), - Some(request.to_vec()), - ) - })?, - *request.get(5).ok_or_else(|| { - RelayError::frame( - FrameErrorKind::InvalidFormat, - "Request too short for register quantity".to_string(), - Some(request.to_vec()), - ) - })?, - ]) - } else { - 1 - }; - - let expected_size = guess_response_size(*function, quantity); - info!("Expected response size: {} bytes", expected_size); - let transaction_start = Instant::now(); let result = tokio::time::timeout(self.config.transaction_timeout, async { let mut port = self.port.lock().await; if self.config.rts_type != RtsType::None { - info!("RTS -> TX mode"); - self.set_rts(self.config.rts_type.to_signal_level(true))?; + if self.trace_frames { + trace!("RTS -> TX mode"); + } + + self.set_rts( + self.config.rts_type.to_signal_level(true), + self.trace_frames, + )?; if self.config.rts_delay_us > 0 { - info!("RTS -> TX mode [waiting]"); + if self.trace_frames { + trace!("RTS -> TX mode [waiting]"); + } tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await; } } // Write request - info!("Writing request"); + if self.trace_frames { + trace!("Writing request"); + } port.write_all(request).map_err(|e| TransportError::Io { operation: IoOperation::Write, details: "Failed to write request".to_string(), @@ -239,22 +220,34 @@ impl RtuTransport { })?; if self.config.rts_type != RtsType::None { - info!("RTS -> RX mode"); - self.set_rts(self.config.rts_type.to_signal_level(false))?; + if self.trace_frames { + trace!("RTS -> RX mode"); + } + + self.set_rts( + self.config.rts_type.to_signal_level(false), + self.trace_frames, + )?; } if self.config.flush_after_write { - info!("RTS -> TX mode [flushing]"); + if self.trace_frames { + trace!("RTS -> TX mode [flushing]"); + } self.tc_flush()?; } if self.config.rts_type != RtsType::None && self.config.rts_delay_us > 0 { - info!("RTS -> RX mode [waiting]"); + if self.trace_frames { + trace!("RTS -> RX mode [waiting]"); + } tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await; } // Read response - trace!("Reading response (expecting {} bytes)", expected_size); + if self.trace_frames { + trace!("Reading response (expecting {} bytes)", expected_size); + } const MAX_TIMEOUTS: u8 = 3; let mut total_bytes = 0; @@ -265,7 +258,6 @@ impl RtuTransport { while total_bytes < expected_size { match port.read(&mut response[total_bytes..]) { Ok(0) => { - trace!("Zero bytes read"); if total_bytes > 0 { let elapsed = last_read_time.elapsed(); if elapsed >= inter_byte_timeout { @@ -276,22 +268,25 @@ impl RtuTransport { tokio::task::yield_now().await; } Ok(n) => { - trace!( - "Read {} bytes: {:02X?}", - n, - &response[total_bytes..total_bytes + n] - ); + if self.trace_frames { + trace!( + "Read {} bytes: {:02X?}", + n, + &response[total_bytes..total_bytes + n] + ); + } total_bytes += n; last_read_time = tokio::time::Instant::now(); consecutive_timeouts = 0; if total_bytes >= expected_size { - trace!("Received complete response"); + if self.trace_frames { + trace!("Received complete response"); + } break; } } Err(e) if e.kind() == std::io::ErrorKind::TimedOut => { - trace!("Read timeout"); if total_bytes > 0 { let elapsed = last_read_time.elapsed(); if elapsed >= inter_byte_timeout { @@ -323,7 +318,6 @@ impl RtuTransport { } if total_bytes == 0 { - info!("No response received"); return Err(TransportError::NoResponse { attempts: consecutive_timeouts, elapsed: transaction_start.elapsed(),