Skip to content

Commit

Permalink
fix: correct response size calculation for modbus functions
Browse files Browse the repository at this point in the history
- Fix incorrect quantity extraction leading to truncated responses and timeouts
- Add get_u16_from_request helper for safe byte extraction from request frames
- Add get_quantity function to properly handle quantity based on function code:
  - Read functions (0x01-0x04) and write multiple (0x0F, 0x10): extract from bytes 4-5
  - Write single functions (0x05, 0x06): fixed quantity of 1
  - Other functions: default to quantity 1
- Improve error handling for malformed request frames
  • Loading branch information
aljen committed Dec 3, 2024
1 parent 3ec7f14 commit 71c76f9
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 60 deletions.
77 changes: 72 additions & 5 deletions src/modbus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,76 @@ pub fn guess_response_size(function: u8, quantity: u16) -> usize {
}
}

/// Extracts a 16-bit unsigned integer from a Modbus RTU request frame starting at the specified index.
///
/// This function attempts to retrieve two consecutive bytes from the provided request slice,
/// starting at the given index, and converts them into a `u16` value using big-endian byte order.
/// If the request slice is too short to contain the required bytes, it returns a `RelayError`
/// indicating an invalid frame format.
///
/// # Arguments
///
/// * `request` - A slice of bytes representing the Modbus RTU request frame.
/// * `start` - The starting index within the request slice from which to extract the `u16` value.
///
/// # Returns
///
/// A `Result` containing the extracted `u16` value if successful, or a `RelayError` if the request
/// slice is too short.
///
/// # Errors
///
/// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain
/// enough bytes to extract a `u16` value starting at the specified index.
fn get_u16_from_request(request: &[u8], start: usize) -> Result<u16, RelayError> {
request
.get(start..start + 2)
.map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]]))
.ok_or_else(|| {
RelayError::frame(
FrameErrorKind::InvalidFormat,
"Request too short for register quantity".to_string(),
Some(request.to_vec()),
)
})
}

/// Extracts the quantity of coils or registers from a Modbus RTU request frame based on the function code.
///
/// This function determines the quantity of coils or registers involved in a Modbus RTU request
/// by examining the function code and extracting the appropriate bytes from the request frame.
/// For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10), it extracts a 16-bit
/// unsigned integer from bytes 4 and 5 of the request frame. For write single functions (0x05, 0x06),
/// it returns a fixed quantity of 1. For other function codes, it defaults to a quantity of 1.
///
/// # Arguments
///
/// * `function_code` - The Modbus function code.
/// * `request` - A slice of bytes representing the Modbus RTU request frame.
///
/// # Returns
///
/// A `Result` containing the extracted quantity as a `u16` value if successful, or a `RelayError` if the request
/// slice is too short or the function code is invalid.
///
/// # Errors
///
/// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain
/// enough bytes to extract the quantity for the specified function code.
pub fn get_quantity(function_code: u8, request: &[u8]) -> Result<u16, RelayError> {
match function_code {
// For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10),
// extract the quantity from bytes 4 and 5 of the request frame.
0x01..=0x04 | 0x0F | 0x10 => get_u16_from_request(request, 4),

// For write single functions (0x05, 0x06), the quantity is always 1.
0x05 | 0x06 => Ok(1),

// For other function codes, default the quantity to 1.
_ => Ok(1),
}
}

pub struct ModbusProcessor {
transport: Arc<RtuTransport>,
}
Expand Down Expand Up @@ -147,11 +217,8 @@ impl ModbusProcessor {

// Estimate the expected RTU response size
let function_code = pdu.first().copied().unwrap_or(0);
let quantity = if pdu.len() >= 4 {
u16::from_be_bytes([pdu[2], pdu[3]])
} else {
0
};
let quantity = get_quantity(function_code, &rtu_request)?;

let expected_response_size = guess_response_size(function_code, quantity);

// Allocate buffer for RTU response
Expand Down
104 changes: 49 additions & 55 deletions src/rtu_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ use tracing::{info, trace};

use crate::{RtsError, RtsType};

use crate::{
guess_response_size, FrameErrorKind, IoOperation, RelayError, RtuConfig, TransportError,
};
use crate::{FrameErrorKind, IoOperation, RelayError, RtuConfig, TransportError};

pub struct RtuTransport {
port: Mutex<Box<dyn SerialPort>>,
Expand Down Expand Up @@ -99,7 +97,7 @@ impl RtuTransport {
Ok(())
}

fn set_rts(&self, on: bool) -> Result<(), TransportError> {
fn set_rts(&self, on: bool, trace_frames: bool) -> Result<(), TransportError> {
let rts_span = tracing::info_span!(
"rts_control",
signal = if on { "HIGH" } else { "LOW" },
Expand Down Expand Up @@ -137,7 +135,9 @@ impl RtuTransport {
))));
}

info!("RTS set to {}", if on { "HIGH" } else { "LOW" });
if trace_frames {
trace!("RTS set to {}", if on { "HIGH" } else { "LOW" });
}
}

Ok(())
Expand Down Expand Up @@ -173,59 +173,40 @@ impl RtuTransport {
));
}

let expected_size = response.len();

if self.trace_frames {
trace!("TX: {} bytes: {:02X?}", request.len(), request);
trace!("Expected response size: {} bytes", expected_size);
}

let function = request.get(1).ok_or_else(|| {
RelayError::frame(
FrameErrorKind::InvalidFormat,
"Request too short to contain function code".to_string(),
Some(request.to_vec()),
)
})?;

let quantity = if *function == 0x03 || *function == 0x04 {
u16::from_be_bytes([
*request.get(4).ok_or_else(|| {
RelayError::frame(
FrameErrorKind::InvalidFormat,
"Request too short for register quantity".to_string(),
Some(request.to_vec()),
)
})?,
*request.get(5).ok_or_else(|| {
RelayError::frame(
FrameErrorKind::InvalidFormat,
"Request too short for register quantity".to_string(),
Some(request.to_vec()),
)
})?,
])
} else {
1
};

let expected_size = guess_response_size(*function, quantity);
info!("Expected response size: {} bytes", expected_size);

let transaction_start = Instant::now();

let result = tokio::time::timeout(self.config.transaction_timeout, async {
let mut port = self.port.lock().await;

if self.config.rts_type != RtsType::None {
info!("RTS -> TX mode");
self.set_rts(self.config.rts_type.to_signal_level(true))?;
if self.trace_frames {
trace!("RTS -> TX mode");
}

self.set_rts(
self.config.rts_type.to_signal_level(true),
self.trace_frames,
)?;

if self.config.rts_delay_us > 0 {
info!("RTS -> TX mode [waiting]");
if self.trace_frames {
trace!("RTS -> TX mode [waiting]");
}
tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await;
}
}

// Write request
info!("Writing request");
if self.trace_frames {
trace!("Writing request");
}
port.write_all(request).map_err(|e| TransportError::Io {
operation: IoOperation::Write,
details: "Failed to write request".to_string(),
Expand All @@ -239,22 +220,34 @@ impl RtuTransport {
})?;

if self.config.rts_type != RtsType::None {
info!("RTS -> RX mode");
self.set_rts(self.config.rts_type.to_signal_level(false))?;
if self.trace_frames {
trace!("RTS -> RX mode");
}

self.set_rts(
self.config.rts_type.to_signal_level(false),
self.trace_frames,
)?;
}

if self.config.flush_after_write {
info!("RTS -> TX mode [flushing]");
if self.trace_frames {
trace!("RTS -> TX mode [flushing]");
}
self.tc_flush()?;
}

if self.config.rts_type != RtsType::None && self.config.rts_delay_us > 0 {
info!("RTS -> RX mode [waiting]");
if self.trace_frames {
trace!("RTS -> RX mode [waiting]");
}
tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await;
}

// Read response
trace!("Reading response (expecting {} bytes)", expected_size);
if self.trace_frames {
trace!("Reading response (expecting {} bytes)", expected_size);
}

const MAX_TIMEOUTS: u8 = 3;
let mut total_bytes = 0;
Expand All @@ -265,7 +258,6 @@ impl RtuTransport {
while total_bytes < expected_size {
match port.read(&mut response[total_bytes..]) {
Ok(0) => {
trace!("Zero bytes read");
if total_bytes > 0 {
let elapsed = last_read_time.elapsed();
if elapsed >= inter_byte_timeout {
Expand All @@ -276,22 +268,25 @@ impl RtuTransport {
tokio::task::yield_now().await;
}
Ok(n) => {
trace!(
"Read {} bytes: {:02X?}",
n,
&response[total_bytes..total_bytes + n]
);
if self.trace_frames {
trace!(
"Read {} bytes: {:02X?}",
n,
&response[total_bytes..total_bytes + n]
);
}
total_bytes += n;
last_read_time = tokio::time::Instant::now();
consecutive_timeouts = 0;

if total_bytes >= expected_size {
trace!("Received complete response");
if self.trace_frames {
trace!("Received complete response");
}
break;
}
}
Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
trace!("Read timeout");
if total_bytes > 0 {
let elapsed = last_read_time.elapsed();
if elapsed >= inter_byte_timeout {
Expand Down Expand Up @@ -323,7 +318,6 @@ impl RtuTransport {
}

if total_bytes == 0 {
info!("No response received");
return Err(TransportError::NoResponse {
attempts: consecutive_timeouts,
elapsed: transaction_start.elapsed(),
Expand Down

0 comments on commit 71c76f9

Please sign in to comment.