diff --git a/emulator/app/Cargo.toml b/emulator/app/Cargo.toml index f6e8307..3e8ce1d 100644 --- a/emulator/app/Cargo.toml +++ b/emulator/app/Cargo.toml @@ -38,3 +38,4 @@ test-flash-ctrl-init = [] test-flash-ctrl-read-write-page = [] test-flash-ctrl-erase-page = [] test-mctp-ctrl-cmds = ["emulator-periph/test-mctp-ctrl-cmds"] +test-mctp-send-loopback = ["emulator-periph/test-mctp-send-loopback"] diff --git a/emulator/app/src/i3c_socket.rs b/emulator/app/src/i3c_socket.rs index 61f5797..d1c954a 100644 --- a/emulator/app/src/i3c_socket.rs +++ b/emulator/app/src/i3c_socket.rs @@ -146,7 +146,7 @@ fn handle_i3c_socket_connection( Err(e) => panic!("Error reading message from socket: {}", e), } if let Ok(response) = bus_response_rx.recv_timeout(Duration::from_millis(10)) { - let data_len = response.resp.data.len(); + let data_len = response.resp.resp.data_length() as usize; if data_len > 255 { panic!("Cannot write more than 255 bytes to socket"); } @@ -157,16 +157,23 @@ fn handle_i3c_socket_connection( }; let header_bytes: [u8; 6] = transmute!(outgoing_header); stream.write_all(&header_bytes).unwrap(); - stream.write_all(&response.resp.data).unwrap(); + if data_len > 0 { + stream.write_all(&response.resp.data[..data_len]).unwrap(); + } } } } +pub(crate) trait TestTrait { + fn run_test(&mut self, running: Arc, stream: &mut TcpStream, target_addr: u8); + fn is_passed(&self) -> bool; +} + pub(crate) fn run_tests( running: Arc, port: u16, target_addr: DynamicI3cAddress, - tests: Vec, + tests: Vec>, ) { let running_clone = running.clone(); let addr = SocketAddr::from(([127, 0, 0, 1], port)); @@ -178,7 +185,7 @@ pub(crate) fn run_tests( } #[derive(Debug, Clone)] -enum TestState { +pub enum TestState { Start, SendPrivateWrite, WaitForIbi, @@ -186,138 +193,12 @@ enum TestState { Finish, } -#[derive(Debug, Clone)] -pub(crate) struct Test { - name: String, - state: TestState, - pvt_write_data: Vec, - pvt_read_data: Vec, - passed: bool, -} - -impl Test { - pub(crate) fn new(name: &str, pvt_write_data: Vec, pvt_read_data: Vec) -> Self { - Self { - name: name.to_string(), - state: TestState::Start, - pvt_write_data, - pvt_read_data, - passed: false, - } - } - - fn is_passed(&self) -> bool { - self.passed - } - - fn check_response(&mut self, data: &[u8]) { - if data.len() == self.pvt_read_data.len() && data == self.pvt_read_data { - self.passed = true; - } - } - - fn run_test(&mut self, running: Arc, stream: &mut TcpStream, target_addr: u8) { - stream.set_nonblocking(true).unwrap(); - while running.load(Ordering::Relaxed) { - match self.state { - TestState::Start => { - println!("Starting test: {}", self.name); - self.state = TestState::SendPrivateWrite; - } - TestState::SendPrivateWrite => self.send_private_write(stream, target_addr), - TestState::WaitForIbi => self.receive_ibi(stream, target_addr), - TestState::ReceivePrivateRead => self.receive_private_read(stream, target_addr), - TestState::Finish => { - println!( - "Test {} : {}", - self.name, - if self.passed { "PASSED" } else { "FAILED" } - ); - break; - } - } - } - } - - fn send_private_write(&mut self, stream: &mut TcpStream, target_addr: u8) { - let addr: u8 = target_addr; - let pvt_write_data = self.pvt_write_data.as_slice(); - - let pec = calculate_crc8(addr << 1, pvt_write_data); - - let mut pkt = Vec::new(); - pkt.extend_from_slice(pvt_write_data); - pkt.push(pec); - - let pvt_write_cmd = prepare_private_write_cmd(addr, pkt.len() as u16); - stream.set_nonblocking(false).unwrap(); - stream.write_all(&pvt_write_cmd).unwrap(); - stream.set_nonblocking(true).unwrap(); - stream.write_all(&pkt).unwrap(); - self.state = TestState::WaitForIbi; - } - - fn receive_ibi(&mut self, stream: &mut TcpStream, target_addr: u8) { - let mut out_header_bytes: [u8; 6] = [0u8; 6]; - match stream.read_exact(&mut out_header_bytes) { - Ok(()) => { - let outdata: OutgoingHeader = transmute!(out_header_bytes); - if outdata.ibi != 0 && outdata.from_addr == target_addr { - let pvt_read_cmd = prepare_private_read_cmd(target_addr); - stream.set_nonblocking(false).unwrap(); - stream.write_all(&pvt_read_cmd).unwrap(); - stream.set_nonblocking(true).unwrap(); - self.state = TestState::ReceivePrivateRead; - } - } - Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} - Err(e) => panic!("Error reading message from socket: {}", e), - } - } - - fn receive_private_read(&mut self, stream: &mut TcpStream, target_addr: u8) { - let mut out_header_bytes = [0u8; 6]; - match stream.read_exact(&mut out_header_bytes) { - Ok(()) => { - let outdata: OutgoingHeader = transmute!(out_header_bytes); - if target_addr != outdata.from_addr { - return; - } - let resp_desc = outdata.response_descriptor; - let data_len = resp_desc.data_length() as usize; - let mut data = vec![0u8; data_len]; - - stream.set_nonblocking(false).unwrap(); - stream - .read_exact(&mut data) - .expect("Failed to read message from socket"); - stream.set_nonblocking(true).unwrap(); - - let pec = calculate_crc8((target_addr << 1) | 1, &data[..data.len() - 1]); - if pec == data[data.len() - 1] { - self.check_response(&data[..data.len() - 1]); - } else { - println!( - "Received data with invalid CRC8: calclulated {:X} != received {:X}", - pec, - data[data.len() - 1] - ); - } - - self.state = TestState::Finish; - } - Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} - Err(e) => panic!("Error reading message from socket: {}", e), - } - } -} - struct TestRunner { stream: TcpStream, target_addr: u8, passed: usize, running: Arc, - tests: Vec, + tests: Vec>, } impl TestRunner { @@ -325,7 +206,7 @@ impl TestRunner { stream: TcpStream, target_addr: u8, running: Arc, - tests: Vec, + tests: Vec>, ) -> Self { Self { stream, @@ -352,6 +233,78 @@ impl TestRunner { } } +pub fn send_private_write(stream: &mut TcpStream, target_addr: u8, data: Vec) -> bool { + let addr: u8 = target_addr; + + let pec = calculate_crc8(addr << 1, data.as_slice()); + + let mut pkt = Vec::new(); + pkt.extend_from_slice(data.as_slice()); + pkt.push(pec); + + let pvt_write_cmd = prepare_private_write_cmd(addr, pkt.len() as u16); + stream.set_nonblocking(false).unwrap(); + stream.write_all(&pvt_write_cmd).unwrap(); + stream.set_nonblocking(true).unwrap(); + stream.write_all(&pkt).unwrap(); + true +} + +pub fn receive_ibi(stream: &mut TcpStream, target_addr: u8) -> bool { + let mut out_header_bytes: [u8; 6] = [0u8; 6]; + match stream.read_exact(&mut out_header_bytes) { + Ok(()) => { + let outdata: OutgoingHeader = transmute!(out_header_bytes); + if outdata.ibi != 0 && outdata.from_addr == target_addr { + let pvt_read_cmd = prepare_private_read_cmd(target_addr); + stream.set_nonblocking(false).unwrap(); + stream.write_all(&pvt_read_cmd).unwrap(); + stream.set_nonblocking(true).unwrap(); + return true; + } + } + Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} + Err(e) => panic!("Error reading message from socket: {}", e), + } + false +} + +pub fn receive_private_read(stream: &mut TcpStream, target_addr: u8) -> Option> { + let mut out_header_bytes = [0u8; 6]; + match stream.read_exact(&mut out_header_bytes) { + Ok(()) => { + let outdata: OutgoingHeader = transmute!(out_header_bytes); + if target_addr != outdata.from_addr { + return None; + } + let resp_desc = outdata.response_descriptor; + let data_len = resp_desc.data_length() as usize; + let mut data = vec![0u8; data_len]; + + stream.set_nonblocking(false).unwrap(); + stream + .read_exact(&mut data) + .expect("Failed to read message from socket"); + stream.set_nonblocking(true).unwrap(); + + let pec = calculate_crc8((target_addr << 1) | 1, &data[..data.len() - 1]); + if pec != data[data.len() - 1] { + println!( + "Received data with invalid CRC8: calclulated {:X} != received {:X}", + pec, + data[data.len() - 1] + ); + return None; + } + + return Some(data[..data.len() - 1].to_vec()); + } + Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} + Err(e) => panic!("Error reading message from socket: {}", e), + } + None +} + fn prepare_private_write_cmd(to_addr: u8, data_len: u16) -> [u8; 9] { let mut write_cmd = ReguDataTransferCommand::read_from_bytes(&[0; 8]).unwrap(); write_cmd.set_rnw(0); diff --git a/emulator/app/src/main.rs b/emulator/app/src/main.rs index 3a124af..d8aab52 100644 --- a/emulator/app/src/main.rs +++ b/emulator/app/src/main.rs @@ -360,7 +360,7 @@ fn run(cli: Emulator, capture_uart_output: bool) -> io::Result> { if cfg!(feature = "test-mctp-ctrl-cmds") { i3c_controller.start(); println!( - "Starting test thread for testing target {:?}", + "Starting test-mctp-ctrl-cmds test thread for testing target {:?}", i3c.get_dynamic_address().unwrap() ); @@ -371,6 +371,20 @@ fn run(cli: Emulator, capture_uart_output: bool) -> io::Result> { i3c.get_dynamic_address().unwrap(), tests, ); + } else if cfg!(feature = "test-mctp-send-loopback") { + i3c_controller.start(); + println!( + "Starting loopback test thread for testing target {:?}", + i3c.get_dynamic_address().unwrap() + ); + + let tests = tests::mctp_loopback::generate_tests(); + i3c_socket::run_tests( + running.clone(), + cli.i3c_port.unwrap(), + i3c.get_dynamic_address().unwrap(), + tests, + ); } let flash_ctrl_error_irq = pic.register_irq(CaliptraRootBus::FLASH_CTRL_ERROR_IRQ); diff --git a/emulator/app/src/tests/mctp_ctrl_cmd.rs b/emulator/app/src/tests/mctp_ctrl_cmd.rs index 401a7ee..1f77469 100644 --- a/emulator/app/src/tests/mctp_ctrl_cmd.rs +++ b/emulator/app/src/tests/mctp_ctrl_cmd.rs @@ -1,14 +1,17 @@ // Licensed under the Apache-2.0 license -use crate::i3c_socket::Test; +use crate::i3c_socket::{ + receive_ibi, receive_private_read, send_private_write, TestState, TestTrait, +}; use crate::tests::mctp_util::base_protocol::{ MCTPHdr, MCTPMsgHdr, MCTP_HDR_SIZE, MCTP_MSG_HDR_SIZE, }; use crate::tests::mctp_util::ctrl_protocol::*; - +use std::net::TcpStream; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use strum::IntoEnumIterator; use strum_macros::EnumIter; - use zerocopy::IntoBytes; const TEST_TARGET_EID: u8 = 0xA; @@ -38,13 +41,13 @@ pub(crate) enum MCTPCtrlCmdTests { } impl MCTPCtrlCmdTests { - pub fn generate_tests() -> Vec { + pub fn generate_tests() -> Vec> { MCTPCtrlCmdTests::iter() .map(|test_id| { let test_name = test_id.name(); let req_data = test_id.generate_request_packet(); let resp_data = test_id.generate_response_packet(); - Test::new(test_name, req_data, resp_data) + Box::new(Test::new(test_name, req_data, resp_data)) as Box }) .collect() } @@ -193,3 +196,72 @@ impl MCTPCtrlCmdTests { } } } + +#[derive(Debug, Clone)] +struct Test { + name: String, + state: TestState, + pvt_write_data: Vec, + pvt_read_data: Vec, + passed: bool, +} + +impl Test { + fn new(name: &str, pvt_write_data: Vec, pvt_read_data: Vec) -> Self { + Self { + name: name.to_string(), + state: TestState::Start, + pvt_write_data, + pvt_read_data, + passed: false, + } + } + + fn check_response(&mut self, data: &[u8]) { + if data.len() == self.pvt_read_data.len() && data == self.pvt_read_data { + self.passed = true; + } + } +} + +impl TestTrait for Test { + fn is_passed(&self) -> bool { + self.passed + } + + fn run_test(&mut self, running: Arc, stream: &mut TcpStream, target_addr: u8) { + stream.set_nonblocking(true).unwrap(); + while running.load(Ordering::Relaxed) { + match self.state { + TestState::Start => { + println!("Starting test: {}", self.name); + self.state = TestState::SendPrivateWrite; + } + TestState::SendPrivateWrite => { + if send_private_write(stream, target_addr, self.pvt_write_data.clone()) { + self.state = TestState::WaitForIbi; + } + } + TestState::WaitForIbi => { + if receive_ibi(stream, target_addr) { + self.state = TestState::ReceivePrivateRead; + } + } + TestState::ReceivePrivateRead => { + if let Some(data) = receive_private_read(stream, target_addr) { + self.check_response(data.as_slice()); + self.state = TestState::Finish; + } + } + TestState::Finish => { + println!( + "Test {} : {}", + self.name, + if self.passed { "PASSED" } else { "FAILED" } + ); + break; + } + } + } + } +} diff --git a/emulator/app/src/tests/mctp_loopback.rs b/emulator/app/src/tests/mctp_loopback.rs new file mode 100644 index 0000000..e2fc575 --- /dev/null +++ b/emulator/app/src/tests/mctp_loopback.rs @@ -0,0 +1,96 @@ +// Licensed under the Apache-2.0 license + +use crate::i3c_socket::{ + receive_ibi, receive_private_read, send_private_write, TestState, TestTrait, +}; +use crate::tests::mctp_util::base_protocol::{MCTPHdr, MCTP_HDR_SIZE}; +use std::collections::VecDeque; +use std::net::TcpStream; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use zerocopy::FromBytes; + +pub fn generate_tests() -> Vec> { + vec![Box::new(Test::new("MctpMultiPktTest")) as Box] +} + +struct Test { + test_name: String, + state: TestState, + loopbak_pkts: VecDeque>, + passed: bool, +} + +impl Test { + fn new(test_name: &str) -> Self { + Test { + test_name: test_name.to_string(), + state: TestState::Start, + loopbak_pkts: VecDeque::new(), + passed: false, + } + } + + fn process_received_packet(&mut self, data: Vec) { + let mut resp_pkt = data.clone(); + let mctp_hdr: &mut MCTPHdr<[u8; MCTP_HDR_SIZE]> = + MCTPHdr::mut_from_bytes(&mut resp_pkt[0..MCTP_HDR_SIZE]).unwrap(); + if mctp_hdr.som() == 1 { + self.loopbak_pkts.clear(); + } + let src_eid = mctp_hdr.src_eid(); + mctp_hdr.set_src_eid(mctp_hdr.dest_eid()); + mctp_hdr.set_dest_eid(src_eid); + mctp_hdr.set_tag_owner(0); + + if mctp_hdr.eom() == 1 { + self.state = TestState::SendPrivateWrite; + } else { + self.state = TestState::WaitForIbi; + } + + self.loopbak_pkts.push_back(resp_pkt); + } +} + +impl TestTrait for Test { + fn is_passed(&self) -> bool { + self.passed + } + + fn run_test(&mut self, running: Arc, stream: &mut TcpStream, target_addr: u8) { + stream.set_nonblocking(true).unwrap(); + while running.load(Ordering::Relaxed) { + match self.state { + TestState::Start => { + println!("Starting test: {}", self.test_name); + self.state = TestState::WaitForIbi; + } + TestState::SendPrivateWrite => { + if let Some(write_pkt) = self.loopbak_pkts.pop_front() { + if send_private_write(stream, target_addr, write_pkt) { + self.state = TestState::SendPrivateWrite; + } else { + self.state = TestState::Finish; + } + } else { + self.state = TestState::WaitForIbi; + } + } + TestState::WaitForIbi => { + if receive_ibi(stream, target_addr) { + self.state = TestState::ReceivePrivateRead; + } + } + TestState::ReceivePrivateRead => { + if let Some(data) = receive_private_read(stream, target_addr) { + self.process_received_packet(data); + } + } + TestState::Finish => { + self.passed = true; + } + } + } + } +} diff --git a/emulator/app/src/tests/mctp_util/base_protocol.rs b/emulator/app/src/tests/mctp_util/base_protocol.rs index d0132ef..5539cf5 100644 --- a/emulator/app/src/tests/mctp_util/base_protocol.rs +++ b/emulator/app/src/tests/mctp_util/base_protocol.rs @@ -1,14 +1,14 @@ // Licensed under the Apache-2.0 license use bitfield::bitfield; -use zerocopy::{FromBytes, Immutable, IntoBytes}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; pub const MCTP_HDR_SIZE: usize = 4; pub const MCTP_MSG_HDR_SIZE: usize = 1; bitfield! { #[repr(C)] - #[derive(Clone, FromBytes, IntoBytes, Immutable, PartialEq)] + #[derive(Clone, FromBytes, IntoBytes, Immutable, KnownLayout, PartialEq)] pub struct MCTPHdr(MSB0 [u8]); impl Debug; u8; diff --git a/emulator/app/src/tests/mod.rs b/emulator/app/src/tests/mod.rs index ddacdd0..20746c5 100644 --- a/emulator/app/src/tests/mod.rs +++ b/emulator/app/src/tests/mod.rs @@ -3,3 +3,4 @@ #[macro_use] pub mod mctp_util; pub mod mctp_ctrl_cmd; +pub mod mctp_loopback; diff --git a/emulator/periph/Cargo.toml b/emulator/periph/Cargo.toml index 87bb222..fb75012 100644 --- a/emulator/periph/Cargo.toml +++ b/emulator/periph/Cargo.toml @@ -32,3 +32,4 @@ test-i3c-constant-writes = [] test-flash-ctrl-read-write-page = [] test-flash-ctrl-erase-page = [] test-mctp-ctrl-cmds = [] +test-mctp-send-loopback = [] diff --git a/emulator/periph/src/i3c.rs b/emulator/periph/src/i3c.rs index 6c92c81..481a073 100644 --- a/emulator/periph/src/i3c.rs +++ b/emulator/periph/src/i3c.rs @@ -107,6 +107,7 @@ impl I3c { let data_size = resp_desc.data_length().into(); if let Some(_data) = self.tti_tx_data_raw.front() { if self.tti_tx_data_raw[0].len() >= data_size { + self.tti_tx_desc_queue_raw.pop_front(); let resp = I3cTcriResponseXfer { resp: resp_desc, data: self.tti_tx_data_raw.pop_front().unwrap(), diff --git a/emulator/periph/src/i3c_protocol.rs b/emulator/periph/src/i3c_protocol.rs index d8db929..94b9e4a 100644 --- a/emulator/periph/src/i3c_protocol.rs +++ b/emulator/periph/src/i3c_protocol.rs @@ -191,6 +191,11 @@ impl I3cController { .iter_mut() .flat_map(|target| { let mut v = vec![]; + v.extend(target.get_response().map(|resp| I3cBusResponse { + ibi: None, + addr: target.get_address().unwrap(), + resp, + })); v.extend(target.get_ibis().iter().map(|mdb| { I3cBusResponse { ibi: Some(*mdb), @@ -198,11 +203,6 @@ impl I3cController { resp: I3cTcriResponseXfer::default(), // empty descriptor for the IBI } })); - v.extend(target.get_response().map(|resp| I3cBusResponse { - ibi: None, - addr: target.get_address().unwrap(), - resp, - })); v }) .collect() diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 42b9f09..d30a4f2 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -29,3 +29,4 @@ test-flash-ctrl-init = [] test-flash-ctrl-read-write-page = [] test-flash-ctrl-erase-page = [] test-mctp-ctrl-cmds = [] +test-mctp-send-loopback = [] diff --git a/runtime/capsules/src/lib.rs b/runtime/capsules/src/lib.rs index 1e13506..44a8544 100644 --- a/runtime/capsules/src/lib.rs +++ b/runtime/capsules/src/lib.rs @@ -3,4 +3,6 @@ #![cfg_attr(target_arch = "riscv32", no_std)] #![forbid(unsafe_code)] +pub mod test; + pub mod mctp; diff --git a/runtime/capsules/src/mctp/base_protocol.rs b/runtime/capsules/src/mctp/base_protocol.rs index a83d1a4..f1a6a5b 100644 --- a/runtime/capsules/src/mctp/base_protocol.rs +++ b/runtime/capsules/src/mctp/base_protocol.rs @@ -8,6 +8,8 @@ use bitfield::bitfield; use zerocopy::{FromBytes, Immutable, IntoBytes}; +pub const MCTP_TEST_MSG_TYPE: u8 = 0x70; + pub const MCTP_TAG_OWNER: u8 = 0x08; pub const MCTP_TAG_MASK: u8 = 0x07; @@ -84,6 +86,7 @@ pub enum MessageType { Spdm = 5, SecureSpdm = 6, VendorDefinedPci = 0x7E, + TestMsgType = MCTP_TEST_MSG_TYPE as isize, Invalid, } @@ -95,6 +98,7 @@ impl From for MessageType { 5 => MessageType::Spdm, 6 => MessageType::SecureSpdm, 0x7E => MessageType::VendorDefinedPci, + MCTP_TEST_MSG_TYPE => MessageType::TestMsgType, _ => MessageType::Invalid, } } diff --git a/runtime/capsules/src/mctp/driver.rs b/runtime/capsules/src/mctp/driver.rs index c38efff..4f287df 100644 --- a/runtime/capsules/src/mctp/driver.rs +++ b/runtime/capsules/src/mctp/driver.rs @@ -1,6 +1,6 @@ // Licensed under the Apache-2.0 license -use crate::mctp::base_protocol::*; +use crate::mctp::base_protocol::{valid_eid, MessageType, MCTP_TAG_OWNER}; use crate::mctp::recv::MCTPRxClient; use crate::mctp::send::{MCTPSender, MCTPTxClient}; use core::cell::Cell; @@ -135,12 +135,7 @@ impl<'a> MCTPDriver<'a> { } fn supported_msg_type(&self, msg_type: u8) -> bool { - for mtype in self.msg_types.iter() { - if msg_type == *mtype as u8 { - return true; - } - } - false + self.msg_types.iter().any(|&t| t as u8 == msg_type) } fn validate_args( diff --git a/runtime/capsules/src/mctp/mux.rs b/runtime/capsules/src/mctp/mux.rs index 728df8c..f55b8ee 100644 --- a/runtime/capsules/src/mctp/mux.rs +++ b/runtime/capsules/src/mctp/mux.rs @@ -356,7 +356,7 @@ impl<'a, M: MCTPTransportBinding<'a>> TransportTxClient for MuxMCTPDriver<'a, M> let mut cur_sender = self.sender_list.head(); if let Some(sender) = cur_sender { - if sender.is_eom() { + if sender.is_eom() || result.is_err() { sender.send_done(result); self.sender_list.pop_head(); cur_sender = self.sender_list.head(); @@ -394,7 +394,8 @@ impl<'a, M: MCTPTransportBinding<'a>> TransportRxClient for MuxMCTPDriver<'a, M> MessageType::Pldm | MessageType::Spdm | MessageType::SecureSpdm - | MessageType::VendorDefinedPci => { + | MessageType::VendorDefinedPci + | MessageType::TestMsgType => { self.process_first_packet( mctp_header, msg_type, diff --git a/runtime/capsules/src/mctp/recv.rs b/runtime/capsules/src/mctp/recv.rs index 8fcca33..5d40848 100644 --- a/runtime/capsules/src/mctp/recv.rs +++ b/runtime/capsules/src/mctp/recv.rs @@ -35,6 +35,7 @@ impl<'a> ListNode<'a, MCTPRxState<'a>> for MCTPRxState<'a> { } } +#[derive(Debug)] struct MsgTerminus { msg_type: u8, msg_tag: u8, diff --git a/runtime/capsules/src/mctp/transport_binding.rs b/runtime/capsules/src/mctp/transport_binding.rs index 49e7d86..398dec9 100644 --- a/runtime/capsules/src/mctp/transport_binding.rs +++ b/runtime/capsules/src/mctp/transport_binding.rs @@ -147,6 +147,10 @@ impl<'a> MCTPTransportBinding<'a> for MCTPI3CBinding<'a> { // Make sure there's enough space for the PEC byte if len == 0 || len > self.max_write_len.get() - 1 { + println!( + "MCTPI3CBinding: Invalid length. Expected: {}", + self.max_write_len.get() - 1 + ); Err((ErrorCode::SIZE, self.tx_buffer.take().unwrap()))?; } @@ -165,6 +169,7 @@ impl<'a> MCTPTransportBinding<'a> for MCTPI3CBinding<'a> { } } } else { + println!("MCTPI3CBinding: Invalid length. Expected: {}", len + 1); Err((ErrorCode::SIZE, tx_buffer))?; } } diff --git a/runtime/capsules/src/test/mctp.rs b/runtime/capsules/src/test/mctp.rs new file mode 100644 index 0000000..5e3996c --- /dev/null +++ b/runtime/capsules/src/test/mctp.rs @@ -0,0 +1,145 @@ +// Licensed under the Apache-2.0 license + +use crate::mctp::base_protocol::{MessageType, MCTP_TAG_MASK, MCTP_TAG_OWNER, MCTP_TEST_MSG_TYPE}; +use crate::mctp::recv::MCTPRxClient; +use crate::mctp::send::{MCTPSender, MCTPTxClient}; +use core::cell::Cell; +use core::fmt::Write; +use kernel::utilities::cells::{MapCell, OptionalCell}; +use kernel::utilities::leasable_buffer::SubSliceMut; +use kernel::ErrorCode; +use romtime::println; + +pub const MCTP_TEST_REMOTE_EID: u8 = 0x20; +pub const MCTP_TEST_MSG_SIZE: usize = 1000; + +static TEST_MSG_LEN_ARR: [usize; 4] = [64, 63, 256, 1000]; + +pub trait TestClient { + fn test_result(&self, passed: bool, npassed: usize, ntotal: usize); +} + +pub struct MockMctp<'a> { + mctp_sender: &'a dyn MCTPSender<'a>, + mctp_msg_buf: MapCell>, + msg_type: MessageType, + msg_tag: Cell, + test_client: OptionalCell<&'a dyn TestClient>, + cur_idx: Cell, +} + +impl<'a> MockMctp<'a> { + pub fn new( + mctp_sender: &'a dyn MCTPSender<'a>, + msg_type: MessageType, + mctp_msg_buf: SubSliceMut<'static, u8>, + ) -> Self { + Self { + mctp_sender, + mctp_msg_buf: MapCell::new(mctp_msg_buf), + msg_type, + msg_tag: Cell::new(0), + test_client: OptionalCell::empty(), + cur_idx: Cell::new(0), + } + } + + pub fn set_test_client(&self, test_client: &'a dyn TestClient) { + self.test_client.set(test_client); + } + + fn prepare_send_data(&self, msg_len: usize) { + assert!(self.mctp_msg_buf.map(|buf| buf.len()).unwrap() >= msg_len); + self.mctp_msg_buf.map(|buf| { + buf.reset(); + buf[0] = MCTP_TEST_MSG_TYPE; + for i in 1..msg_len { + buf[i] = i as u8; + } + buf.slice(0..msg_len) + }); + } + + pub fn run_send_loopback_test(&self) { + self.prepare_send_data(TEST_MSG_LEN_ARR[self.cur_idx.get()]); + self.mctp_sender + .send_msg( + self.msg_type as u8, + MCTP_TEST_REMOTE_EID, + MCTP_TAG_OWNER, + self.mctp_msg_buf.take().unwrap(), + ) + .unwrap(); + } +} + +impl<'a> MCTPRxClient for MockMctp<'a> { + fn receive(&self, src_eid: u8, msg_type: u8, msg_tag: u8, msg_payload: &[u8], msg_len: usize) { + println!( + "Received message from EID: {} with message type: {} and message tag: {} msg_len: {}", + src_eid, msg_type, msg_tag, msg_len + ); + + if msg_type != self.msg_type as u8 + || src_eid != MCTP_TEST_REMOTE_EID + || msg_tag != self.msg_tag.get() + || msg_len != TEST_MSG_LEN_ARR[self.cur_idx.get()] + { + self.test_client.map(|client| { + client.test_result(false, self.cur_idx.get() + 1, TEST_MSG_LEN_ARR.len()); + }); + } + + self.mctp_msg_buf.map(|buf| { + if buf[..msg_len] != msg_payload[..msg_len] { + self.test_client.map(|client| { + client.test_result(false, self.cur_idx.get() + 1, TEST_MSG_LEN_ARR.len()); + }); + } + }); + + println!( + "Completed loopback test for message length: {}", + TEST_MSG_LEN_ARR[self.cur_idx.get()] + ); + + if self.cur_idx.get() == TEST_MSG_LEN_ARR.len() - 1 { + self.test_client.map(|client| { + client.test_result(true, self.cur_idx.get() + 1, TEST_MSG_LEN_ARR.len()); + }); + } else { + self.cur_idx.set(self.cur_idx.get() + 1); + self.prepare_send_data(TEST_MSG_LEN_ARR[self.cur_idx.get()]); + self.mctp_sender + .send_msg( + self.msg_type as u8, + MCTP_TEST_REMOTE_EID, + MCTP_TAG_OWNER, + self.mctp_msg_buf.take().unwrap(), + ) + .unwrap(); + } + } +} + +impl<'a> MCTPTxClient for MockMctp<'a> { + fn send_done( + &self, + dest_eid: u8, + msg_type: u8, + msg_tag: u8, + result: Result<(), ErrorCode>, + mut msg_payload: SubSliceMut<'static, u8>, + ) { + assert!(result == Ok(())); + assert!(dest_eid == MCTP_TEST_REMOTE_EID); + assert!(msg_type == self.msg_type as u8); + self.msg_tag.set(msg_tag & MCTP_TAG_MASK); + msg_payload.reset(); + self.mctp_msg_buf.replace(msg_payload); + println!( + "Message sent of length : {}", + TEST_MSG_LEN_ARR[self.cur_idx.get()] + ); + } +} diff --git a/runtime/capsules/src/test/mod.rs b/runtime/capsules/src/test/mod.rs new file mode 100644 index 0000000..d93ca66 --- /dev/null +++ b/runtime/capsules/src/test/mod.rs @@ -0,0 +1,3 @@ +// Licensed under the Apache-2.0 license + +pub mod mctp; diff --git a/runtime/src/board.rs b/runtime/src/board.rs index edd0d4a..3ee970e 100644 --- a/runtime/src/board.rs +++ b/runtime/src/board.rs @@ -6,6 +6,7 @@ use crate::components as runtime_components; use crate::timers::InternalTimers; use capsules_core::virtualizers::virtual_alarm::{MuxAlarm, VirtualMuxAlarm}; +use capsules_runtime::mctp::base_protocol::MessageType; use core::ptr::{addr_of, addr_of_mut}; use kernel::capabilities; use kernel::component::Component; @@ -246,11 +247,8 @@ pub unsafe fn main() { .finalize(crate::mctp_mux_component_static!(MCTPI3CBinding)); let mctp_spdm_msg_types = static_init!( - [capsules_runtime::mctp::base_protocol::MessageType; 2], - [ - capsules_runtime::mctp::base_protocol::MessageType::Spdm, - capsules_runtime::mctp::base_protocol::MessageType::SecureSpdm, - ] + [MessageType; 2], + [MessageType::Spdm, MessageType::SecureSpdm,] ); let mctp_spdm = runtime_components::mctp_driver::MCTPDriverComponent::new( board_kernel, @@ -260,10 +258,7 @@ pub unsafe fn main() { ) .finalize(crate::mctp_driver_component_static!()); - let mctp_pldm_msg_types = static_init!( - [capsules_runtime::mctp::base_protocol::MessageType; 1], - [capsules_runtime::mctp::base_protocol::MessageType::Pldm] - ); + let mctp_pldm_msg_types = static_init!([MessageType; 1], [MessageType::Pldm]); let mctp_pldm = runtime_components::mctp_driver::MCTPDriverComponent::new( board_kernel, capsules_runtime::mctp::driver::MCTP_PLDM_DRIVER_NUM, @@ -272,10 +267,8 @@ pub unsafe fn main() { ) .finalize(crate::mctp_driver_component_static!()); - let mctp_vendor_def_pci_msg_types = static_init!( - [capsules_runtime::mctp::base_protocol::MessageType; 1], - [capsules_runtime::mctp::base_protocol::MessageType::VendorDefinedPci] - ); + let mctp_vendor_def_pci_msg_types = + static_init!([MessageType; 1], [MessageType::VendorDefinedPci]); let mctp_vendor_def_pci = runtime_components::mctp_driver::MCTPDriverComponent::new( board_kernel, capsules_runtime::mctp::driver::MCTP_VENDOR_DEFINED_PCI_DRIVER_NUM, @@ -355,19 +348,22 @@ pub unsafe fn main() { // Run any requested test let exit = if cfg!(feature = "test-i3c-simple") { debug!("Executing test-i3c-simple"); - crate::tests::test_i3c_simple() + crate::tests::i3c_target_test::test_i3c_simple() } else if cfg!(feature = "test-i3c-constant-writes") { debug!("Executing test-i3c-constant-writes"); - crate::tests::test_i3c_constant_writes() + crate::tests::i3c_target_test::test_i3c_constant_writes() } else if cfg!(feature = "test-flash-ctrl-init") { debug!("Executing test-flash-ctrl-init"); - crate::flash_ctrl_test::test_flash_ctrl_init() + crate::tests::flash_ctrl_test::test_flash_ctrl_init() } else if cfg!(feature = "test-flash-ctrl-read-write-page") { debug!("Executing test-flash-ctrl-read-write-page"); - crate::flash_ctrl_test::test_flash_ctrl_read_write_page() + crate::tests::flash_ctrl_test::test_flash_ctrl_read_write_page() } else if cfg!(feature = "test-flash-ctrl-erase-page") { debug!("Executing test-flash-ctrl-erase-page"); - crate::flash_ctrl_test::test_flash_ctrl_erase_page() + crate::tests::flash_ctrl_test::test_flash_ctrl_erase_page() + } else if cfg!(feature = "test-mctp-send-loopback") { + debug!("Executing test-mctp-send-loopback"); + crate::tests::mctp_test::test_mctp_send_loopback(mctp_mux) } else { None }; diff --git a/runtime/src/components/mctp_driver.rs b/runtime/src/components/mctp_driver.rs index 4bf2802..7dcc216 100644 --- a/runtime/src/components/mctp_driver.rs +++ b/runtime/src/components/mctp_driver.rs @@ -11,10 +11,14 @@ //! Usage //! ----- //! ```rust -//! let mctp_driver = MCTPDriverComponent::new().finalize( +//! let spdm_mctp_driver = MCTPDriverComponent::new( +//! board_kernel, +//! capsules_runtime::mctp::driver::MCTP_SPDM_DRIVER_NUM, +//! mctp_mux, +//! mctp_spdm_msg_types, +//! ) +//! .finalize(mctp_driver_component_static!()); //! ``` -//! -//! use capsules_runtime::mctp::base_protocol::MessageType; use capsules_runtime::mctp::driver::{MCTPDriver, MCTP_MAX_MESSAGE_SIZE}; diff --git a/runtime/src/components/mock_mctp.rs b/runtime/src/components/mock_mctp.rs new file mode 100644 index 0000000..cd061b4 --- /dev/null +++ b/runtime/src/components/mock_mctp.rs @@ -0,0 +1,81 @@ +// Licensed under the Apache-2.0 license + +use capsules_runtime::mctp::base_protocol::MessageType; +use capsules_runtime::mctp::driver::MCTP_MAX_MESSAGE_SIZE; +use capsules_runtime::mctp::mux::MuxMCTPDriver; +use capsules_runtime::mctp::recv::MCTPRxState; +use capsules_runtime::mctp::send::{MCTPSender, MCTPTxState}; +use capsules_runtime::mctp::transport_binding::MCTPI3CBinding; +use capsules_runtime::test::mctp::MockMctp; +use core::mem::MaybeUninit; +use kernel::component::Component; +use kernel::utilities::leasable_buffer::SubSliceMut; + +#[macro_export] +macro_rules! mock_mctp_component_static { + () => {{ + use capsules_runtime::mctp::base_protocol::MessageType; + use capsules_runtime::mctp::driver::MCTP_MAX_MESSAGE_SIZE; + use capsules_runtime::mctp::recv::MCTPRxState; + use capsules_runtime::mctp::send::MCTPTxState; + use capsules_runtime::mctp::transport_binding::MCTPI3CBinding; + use capsules_runtime::test::mctp::MockMctp; + + let tx_state = kernel::static_buf!(MCTPTxState<'static, MCTPI3CBinding<'static>>); + let rx_state = kernel::static_buf!(MCTPRxState<'static>); + let rx_msg_buf = kernel::static_buf!([u8; MCTP_MAX_MESSAGE_SIZE]); + let tx_msg_buf = kernel::static_buf!([u8; MCTP_MAX_MESSAGE_SIZE]); + let msg_types = kernel::static_buf!([MessageType; 1]); + let mock_mctp = kernel::static_buf!(MockMctp<'static>); + ( + tx_state, rx_state, rx_msg_buf, tx_msg_buf, msg_types, mock_mctp, + ) + }}; +} + +pub struct MockMctpComponent { + mctp_mux: &'static MuxMCTPDriver<'static, MCTPI3CBinding<'static>>, +} + +impl MockMctpComponent { + pub fn new(mctp_mux: &'static MuxMCTPDriver<'static, MCTPI3CBinding<'static>>) -> Self { + Self { mctp_mux } + } +} + +impl Component for MockMctpComponent { + type StaticInput = ( + &'static mut MaybeUninit>>, + &'static mut MaybeUninit>, + &'static mut MaybeUninit<[u8; MCTP_MAX_MESSAGE_SIZE]>, + &'static mut MaybeUninit<[u8; MCTP_MAX_MESSAGE_SIZE]>, + &'static mut MaybeUninit<[MessageType; 1]>, + &'static mut MaybeUninit>, + ); + type Output = &'static MockMctp<'static>; + + fn finalize(self, static_buffer: Self::StaticInput) -> Self::Output { + let rx_msg_buf = static_buffer.2.write([0; MCTP_MAX_MESSAGE_SIZE]); + let tx_msg_buf = static_buffer.3.write([0; MCTP_MAX_MESSAGE_SIZE]); + + let tx_state = static_buffer.0.write(MCTPTxState::new(self.mctp_mux)); + + let msg_types = static_buffer.4.write([MessageType::TestMsgType; 1]); + + let rx_state = static_buffer + .1 + .write(MCTPRxState::new(rx_msg_buf, msg_types)); + + let mock_mctp = static_buffer.5.write(MockMctp::new( + tx_state, + MessageType::TestMsgType, + SubSliceMut::new(tx_msg_buf), + )); + + tx_state.set_client(mock_mctp); + rx_state.set_client(mock_mctp); + self.mctp_mux.add_receiver(rx_state); + + mock_mctp + } +} diff --git a/runtime/src/components/mod.rs b/runtime/src/components/mod.rs index a09112b..bd39b6d 100644 --- a/runtime/src/components/mod.rs +++ b/runtime/src/components/mod.rs @@ -4,3 +4,4 @@ pub mod mctp_driver; pub mod mctp_mux; +pub mod mock_mctp; diff --git a/runtime/src/main.rs b/runtime/src/main.rs index e05d0b5..4915f3c 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -30,9 +30,6 @@ mod timers; #[cfg(target_arch = "riscv32")] mod flash_ctrl; -#[cfg(target_arch = "riscv32")] -#[allow(unused_imports)] -mod flash_ctrl_test; #[cfg(target_arch = "riscv32")] pub use board::*; diff --git a/runtime/src/flash_ctrl_test.rs b/runtime/src/tests/flash_ctrl_test.rs similarity index 100% rename from runtime/src/flash_ctrl_test.rs rename to runtime/src/tests/flash_ctrl_test.rs diff --git a/runtime/src/tests.rs b/runtime/src/tests/i3c_target_test.rs similarity index 100% rename from runtime/src/tests.rs rename to runtime/src/tests/i3c_target_test.rs diff --git a/runtime/src/tests/mctp_test.rs b/runtime/src/tests/mctp_test.rs new file mode 100644 index 0000000..86dfb3f --- /dev/null +++ b/runtime/src/tests/mctp_test.rs @@ -0,0 +1,50 @@ +// Licensed under the Apache-2.0 license + +use crate::components::mock_mctp::MockMctpComponent; +use capsules_runtime::mctp::mux::MuxMCTPDriver; +use capsules_runtime::mctp::transport_binding::MCTPI3CBinding; +use capsules_runtime::test::mctp::MockMctp; +use capsules_runtime::test::mctp::TestClient; + +use core::fmt::Write; +use romtime::println; + +use kernel::component::Component; +use kernel::static_init; + +pub fn test_mctp_send_loopback( + mctp_mux: &'static MuxMCTPDriver<'static, MCTPI3CBinding<'static>>, +) -> Option { + // set local EID here if needed. + let mock_mctp = + unsafe { MockMctpComponent::new(mctp_mux).finalize(crate::mock_mctp_component_static!()) }; + let mctp_tester = unsafe { static_init!(TestMctp<'static>, TestMctp::new(mock_mctp)) }; + mock_mctp.set_test_client(mctp_tester); + mock_mctp.run_send_loopback_test(); + None +} + +struct TestMctp<'a> { + _mock_mctp: &'a MockMctp<'a>, +} + +impl<'a> TestMctp<'a> { + pub fn new(_mock_mctp: &'static MockMctp<'a>) -> Self { + Self { _mock_mctp } + } +} + +impl<'a> TestClient for TestMctp<'a> { + fn test_result(&self, passed: bool, npassed: usize, ntotal: usize) { + println!("MCTP test result: {}/{} passed", npassed, ntotal); + println!( + "MCTP test result: {}", + if passed { "PASSED" } else { "FAILED" } + ); + if passed { + crate::io::exit_emulator(0); + } else { + crate::io::exit_emulator(1); + } + } +} diff --git a/runtime/src/tests/mod.rs b/runtime/src/tests/mod.rs new file mode 100644 index 0000000..d921d2d --- /dev/null +++ b/runtime/src/tests/mod.rs @@ -0,0 +1,5 @@ +// Licensed under the Apache-2.0 license + +pub(crate) mod flash_ctrl_test; +pub(crate) mod i3c_target_test; +pub(crate) mod mctp_test; diff --git a/tests/integration/src/lib.rs b/tests/integration/src/lib.rs index ff8caa7..8d98603 100644 --- a/tests/integration/src/lib.rs +++ b/tests/integration/src/lib.rs @@ -131,4 +131,5 @@ mod test { run_test!(test_flash_ctrl_read_write_page); run_test!(test_flash_ctrl_erase_page); run_test!(test_mctp_ctrl_cmds); + run_test!(test_mctp_send_loopback); }