Skip to content

Commit

Permalink
Cleaned up traits.
Browse files Browse the repository at this point in the history
  • Loading branch information
Frostie314159 committed Jan 6, 2024
1 parent 70e0c01 commit 2c3cbc8
Show file tree
Hide file tree
Showing 15 changed files with 86 additions and 84 deletions.
5 changes: 1 addition & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
[package]
name = "awdl-frame-parser"
version = "0.3.3"
version = "0.3.4"
edition = "2021"
description = "A parser for AWDL data and action frames."
authors = ["Frostie314159"]
license = "MIT OR Apache-2.0"
readme = "README.md"
repository = "https://github.com/Frostie314159/awdl-frame-parser"

[profile.release]
opt-level = 3

[dev-dependencies]
criterion = { version = "0.4.0", features = ["html_reports"] }

Expand Down
43 changes: 17 additions & 26 deletions src/common/awdl_dns_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,51 +44,43 @@ impl ExactSizeIterator for ReadLabelIterator<'_> {

#[derive(Clone, Debug, Default, Hash)]
/// A hostname combined with the [domain](AWDLDnsCompression).
pub struct AWDLDnsName<I> {
pub struct AWDLDnsName<'a, I = ReadLabelIterator<'a>>
where
I: IntoIterator<Item = AWDLStr<'a>>,
{
/// The labels of the peer.
pub labels: I,

/// The domain in [compressed form](AWDLDnsCompression).
pub domain: AWDLDnsCompression,
}
impl<'a, LhsIterator, RhsIterator> PartialEq<AWDLDnsName<RhsIterator>> for AWDLDnsName<LhsIterator>
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Copy> Copy for AWDLDnsName<'a, I> {}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Clone> Eq for AWDLDnsName<'a, I> {}
impl<'a, LhsIterator, RhsIterator> PartialEq<AWDLDnsName<'a, RhsIterator>>
for AWDLDnsName<'a, LhsIterator>
where
LhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
RhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
<LhsIterator as IntoIterator>::IntoIter: Clone,
<RhsIterator as IntoIterator>::IntoIter: Clone,
{
fn eq(&self, other: &AWDLDnsName<RhsIterator>) -> bool {
self.labels
.clone()
.into_iter()
.eq(other.labels.clone().into_iter())
fn eq(&self, other: &AWDLDnsName<'a, RhsIterator>) -> bool {
self.labels.clone().into_iter().eq(other.labels.clone())
}
}

impl<'a, I> Eq for AWDLDnsName<I>
impl<'a, I> MeasureWith<()> for AWDLDnsName<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
}

impl<'a, I> MeasureWith<()> for AWDLDnsName<I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
fn measure_with(&self, ctx: &()) -> usize {
self.labels
.clone()
.into_iter()
.clone()
.map(|label| label.measure_with(ctx))
.sum::<usize>()
+ 2
}
}
impl<'a> TryFromCtx<'a> for AWDLDnsName<ReadLabelIterator<'a>> {
impl<'a> TryFromCtx<'a> for AWDLDnsName<'a> {
type Error = scroll::Error;
fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
let mut offset = 0;
Expand All @@ -104,7 +96,7 @@ impl<'a> TryFromCtx<'a> for AWDLDnsName<ReadLabelIterator<'a>> {
))
}
}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>>> TryIntoCtx for AWDLDnsName<I> {
impl<'a, I: IntoIterator<Item = AWDLStr<'a>>> TryIntoCtx for AWDLDnsName<'a, I> {
type Error = scroll::Error;
fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
let mut offset = 0;
Expand All @@ -116,10 +108,9 @@ impl<'a, I: IntoIterator<Item = AWDLStr<'a>>> TryIntoCtx for AWDLDnsName<I> {
Ok(offset)
}
}
impl<'a, I> Display for AWDLDnsName<I>
impl<'a, I> Display for AWDLDnsName<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
for label in self.labels.clone() {
Expand All @@ -137,15 +128,15 @@ fn test_dns_name() {
0x04, b'a', b'w', b'd', b'l', 0x04, b'a', b'w', b'd', b'l', 0xc0, 0x0c,
]
.as_slice();
let dns_name = bytes.pread::<AWDLDnsName<ReadLabelIterator>>(0).unwrap();
let dns_name = bytes.pread::<AWDLDnsName>(0).unwrap();
assert_eq!(
dns_name,
AWDLDnsName {
labels: vec!["awdl".into(), "awdl".into()],
labels: ["awdl".into(), "awdl".into()],
domain: AWDLDnsCompression::Local
}
);
let mut buf = [0x00; 12];
let mut buf = vec![0x00; dns_name.measure_with(&())];
buf.pwrite(dns_name, 0).unwrap();
assert_eq!(bytes, buf);
}
6 changes: 4 additions & 2 deletions src/common/awdl_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use scroll::{
Pread, Pwrite,
};

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
/// A string in the format used by AWDL.
/// The characters are preceeded by a length byte.
pub struct AWDLStr<'a>(pub &'a str);
Expand Down Expand Up @@ -63,10 +63,12 @@ impl<'a> From<&'a str> for AWDLStr<'a> {
#[cfg(test)]
#[test]
fn test_awdl_str() {
use alloc::vec;

let bytes = [0x06, 0x6c, 0x61, 0x6d, 0x62, 0x64, 0x61].as_slice();
let string = bytes.pread::<AWDLStr<'_>>(0).unwrap();
assert_eq!(string, "lambda".into());
let mut buf = [0x00; 7];
let mut buf = vec![0x00; string.measure_with(&())];
let _ = buf.pwrite::<AWDLStr<'_>>(string, 0).unwrap();
assert_eq!(bytes, buf);
}
2 changes: 1 addition & 1 deletion src/common/awdl_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::fmt::Display;
use macro_bits::bitfield;

bitfield! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
/// A version in AWDL format.
pub struct AWDLVersion: u8 {
/// The major version.
Expand Down
6 changes: 3 additions & 3 deletions src/tlvs/data_path/ht_capabilities_tlv/ampdu_parameters.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use macro_bits::{bit, bitfield, serializable_enum};

serializable_enum! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MAXAMpduLength: u8 {
/// 8kb
#[default]
Expand All @@ -16,7 +16,7 @@ serializable_enum! {
}

serializable_enum! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MpduDensity: u8 {
#[default]
NoRestriction => 0,
Expand All @@ -31,7 +31,7 @@ serializable_enum! {
}

bitfield! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct AMpduParameters: u8 {
pub max_a_mpdu_length: MAXAMpduLength => bit!(0,1),
pub mpdu_density: MpduDensity => bit!(2,3,4)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use macro_bits::{bit, bitfield, serializable_enum};

serializable_enum! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SmPwSave: u8 {
#[default]
Static => 0,
Expand All @@ -10,7 +10,7 @@ serializable_enum! {
}
}
bitfield! {
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct HTCapabilitiesInfo: u16 {
pub ldpc_coding_capability: bool => bit!(0),
pub support_channel_width: bool => bit!(1),
Expand Down
2 changes: 1 addition & 1 deletion src/tlvs/data_path/ht_capabilities_tlv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use scroll::{
Endian, Pread, Pwrite,
};

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct HTCapabilitiesTLV {
pub ht_capabilities_info: HTCapabilitiesInfo,
pub a_mpdu_parameters: AMpduParameters,
Expand Down
2 changes: 1 addition & 1 deletion src/tlvs/data_path/ieee80211_cntr_tlv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tlv_rs::raw_tlv::RawTLV;

pub type IEEE80211TLV<'a> = RawTLV<'a, u8, u8>;

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
/// This TLV just encapsulates an IEEE802.11 TLV.
///
/// In reality, this just contains an EHT capabilities TLV, but for future compatibility we'll just make it do this now.
Expand Down
36 changes: 24 additions & 12 deletions src/tlvs/dns_sd/arpa_tlv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,29 @@ use scroll::{

use crate::common::{AWDLDnsName, AWDLStr, ReadLabelIterator};

#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(Clone, Debug, Default, Hash)]
/// A TLV containing the hostname of the peer. Used for reverse DNS.
pub struct ArpaTLV<'a, I>
pub struct ArpaTLV<'a, I = ReadLabelIterator<'a>>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
I: IntoIterator<Item = AWDLStr<'a>>,
{
/// The actual arpa data.
pub arpa: AWDLDnsName<I>,
pub arpa: AWDLDnsName<'a, I>,
}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Copy> Copy for ArpaTLV<'a, I> {}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Clone> Eq for ArpaTLV<'a, I> {}
impl<'a, LhsIterator, RhsIterator> PartialEq<ArpaTLV<'a, RhsIterator>> for ArpaTLV<'a, LhsIterator>
where
LhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
RhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
{
fn eq(&self, other: &ArpaTLV<'a, RhsIterator>) -> bool {
self.arpa == other.arpa && self.arpa == other.arpa
}
}
impl<'a, I> MeasureWith<()> for ArpaTLV<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
fn measure_with(&self, ctx: &()) -> usize {
self.arpa.measure_with(ctx) + 1
Expand All @@ -35,8 +44,7 @@ impl<'a> TryFromCtx<'a> for ArpaTLV<'a, ReadLabelIterator<'a>> {
}
impl<'a, I> TryIntoCtx for ArpaTLV<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
I: IntoIterator<Item = AWDLStr<'a>>,
{
type Error = scroll::Error;
fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
Expand All @@ -52,16 +60,20 @@ where
#[test]
fn test_arpa_tlv() {
use crate::common::AWDLDnsCompression;
use alloc::{vec, vec::Vec};
use alloc::vec;
use scroll::{Pread, Pwrite};

let bytes = &include_bytes!("../../../test_bins/arpa_tlv.bin")[3..];

let arpa_tlv = bytes.pread::<ArpaTLV<ReadLabelIterator>>(0).unwrap();
assert_eq!(arpa_tlv.arpa.domain, AWDLDnsCompression::Local);
assert_eq!(
arpa_tlv.arpa.labels.collect::<Vec<_>>(),
vec!["simon-framework".into()]
arpa_tlv,
ArpaTLV {
arpa: AWDLDnsName {
labels: ["simon-framework".into()],
domain: AWDLDnsCompression::Local
}
}
);
let mut buf = vec![0x00; arpa_tlv.measure_with(&())];
buf.as_mut_slice()
Expand Down
22 changes: 9 additions & 13 deletions src/tlvs/dns_sd/service_response_tlv/dns_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@ serializable_enum! {
}
}

#[derive(Clone, Debug, Eq)]
#[derive(Clone, Debug, Hash)]
/// A DNS record as encoded by AWDL.
pub enum AWDLDnsRecord<'a, I>
pub enum AWDLDnsRecord<'a, I = ReadLabelIterator<'a>>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
I: IntoIterator<Item = AWDLStr<'a>>,
{
/// Pointer
PTR {
domain_name: AWDLDnsName<I>,
domain_name: AWDLDnsName<'a, I>,
},
/// Text
TXT {
Expand All @@ -36,7 +35,7 @@ where
priority: u16,
weight: u16,
port: u16,
target: AWDLDnsName<I>,
target: AWDLDnsName<'a, I>,
},
UnknownRecord {
record_type: u8,
Expand All @@ -45,8 +44,7 @@ where
}
impl<'a, I> AWDLDnsRecord<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
I: IntoIterator<Item = AWDLStr<'a>>,
{
#[inline]
/// Returns the [record type](AWDLDnsRecordType).
Expand All @@ -61,13 +59,13 @@ where
}
}
}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Copy> Copy for AWDLDnsRecord<'a, I> {}
impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Clone> Eq for AWDLDnsRecord<'a, I> {}
impl<'a, LhsIterator, RhsIterator> PartialEq<AWDLDnsRecord<'a, RhsIterator>>
for AWDLDnsRecord<'a, LhsIterator>
where
LhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
RhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
<LhsIterator as IntoIterator>::IntoIter: Clone,
<RhsIterator as IntoIterator>::IntoIter: Clone,
{
fn eq(&self, other: &AWDLDnsRecord<'a, RhsIterator>) -> bool {
match (self, other) {
Expand Down Expand Up @@ -121,7 +119,6 @@ where
impl<'a, I> MeasureWith<()> for AWDLDnsRecord<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
fn measure_with(&self, ctx: &()) -> usize {
(match self {
Expand All @@ -136,7 +133,7 @@ where
}) + 1
}
}
impl<'a> TryFromCtx<'a> for AWDLDnsRecord<'a, ReadLabelIterator<'a>> {
impl<'a> TryFromCtx<'a> for AWDLDnsRecord<'a> {
type Error = scroll::Error;
fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
let mut offset = 0;
Expand Down Expand Up @@ -168,7 +165,6 @@ impl<'a> TryFromCtx<'a> for AWDLDnsRecord<'a, ReadLabelIterator<'a>> {
impl<'a, I> TryIntoCtx for AWDLDnsRecord<'a, I>
where
I: IntoIterator<Item = AWDLStr<'a>> + Clone,
<I as IntoIterator>::IntoIter: Clone,
{
type Error = scroll::Error;
fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
Expand Down
Loading

0 comments on commit 2c3cbc8

Please sign in to comment.