Skip to content

Commit

Permalink
feat(metadata): add check for Node ID on gRPC metadata
Browse files Browse the repository at this point in the history
Signed-off-by: rgallor <riccardo.gallo@secomind.com>
  • Loading branch information
rgallor committed Feb 8, 2024
1 parent ab4e7cb commit 09584ea
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 21 deletions.
18 changes: 9 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,33 @@ categories = ["network-programming"]
rust-version = "1.66.1"

[dependencies]
astarte-device-sdk = { version = "0.7.0", features = ["derive", "message-hub"] }
astarte-device-sdk = { git = "https://github.com/rgallor/astarte-device-sdk-rust.git", branch = "grpc-metadata", features = ["derive", "message-hub"] }
astarte-message-hub-proto = "0.6.1"
async-trait = "0.1.77"
axum = "0.7.4"
chrono = "0.4.33"
clap = { version = "3.2.25", features = ["derive"] }
displaydoc = "0.2.4"
env_logger = "0.10.2"
hyper = "0.14.26"
log = "0.4.20"
pbjson-types = "0.6.0"
prost = "0.12.3"
serde = "1.0.195"
serde_json = "1.0"
serde_json = "1.0.0"
thiserror = "1.0.56"
tokio = { version = "1.35.1", features = ["rt-multi-thread", "sync", "macros", "signal"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
tokio-util ="0.7.10"
toml = "0.5.9"
tonic = "0.10.2"
tower = "0.4.13"
uuid = "1.7.0"
zbus = { version = "3.14.1", default-features = false, features = ["tokio"] }

# pinned transitive dependenciees
predicates = "=3.0.3"
anstyle = "=1.0.2"
predicates = "=3.0.3"

[dev-dependencies]
mockall = "0.12.1"
Expand Down
13 changes: 7 additions & 6 deletions examples/client/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use astarte_device_sdk::builder::{DeviceBuilder, DeviceSdkBuild};
use astarte_device_sdk::store::memory::MemoryStore;
use astarte_device_sdk::transport::grpc::GrpcConfig;
use astarte_device_sdk::types::AstarteType;
use astarte_device_sdk::Client;
use astarte_device_sdk::{Client, ClientDisconnect};
use std::time;

use clap::Parser;
use log::{error, info, warn};
use log::{debug, error, info, warn};
use tokio::select;
use tokio::signal::ctrl_c;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -66,11 +66,9 @@ async fn main() -> Result<(), DynError> {

let (mut node, mut rx_events) = DeviceBuilder::new()
.store(MemoryStore::new())
.interface_directory("examples/client/interfaces")
.expect("failed to use interface directory")
.interface_directory("examples/client/interfaces")?
.connect(grpc_cfg)
.await
.expect("failed to connect")
.await?
.build();

let receive_handle = tokio::spawn(async move {
Expand Down Expand Up @@ -127,6 +125,7 @@ async fn main() -> Result<(), DynError> {
// wait for CTRL C to terminate the node execution
select! {
_ = ctrl_c() => {
debug!("CTRL C received, stop sending/receiving data from Astarte");
send_handle.abort();
receive_handle.abort();
}
Expand All @@ -139,6 +138,8 @@ async fn main() -> Result<(), DynError> {
handle_task(receive_handle).await;
handle_task(send_handle).await;

node.disconnect().await;

Ok(())
}

Expand Down
123 changes: 121 additions & 2 deletions src/astarte_message_hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@
*/
//! Contains the implementation for the Astarte message hub.
use hyper::{http, Body};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use astarte_message_hub_proto::types::InterfaceJson;
use astarte_message_hub_proto::AstarteMessage;
use log::info;
use clap::__macro_refs::once_cell::sync::Lazy;
use log::{error, info};
use thiserror::Error;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tonic::body::BoxBody;
use tonic::{Request, Response, Status};
use tower::{Layer, Service};
use uuid::Uuid;
use zbus::export::futures_util::FutureExt;

use crate::data::astarte::{AstartePublisher, AstarteRunner, AstarteSubscriber};

Expand All @@ -41,6 +49,7 @@ pub struct AstarteMessageHub<T: Clone + AstarteRunner + AstartePublisher + Astar
}

/// A single node that can be connected to the Astarte message hub.
#[derive(Clone)]
pub struct AstarteNode {
/// Identifier for the node
pub id: Uuid,
Expand Down Expand Up @@ -80,6 +89,116 @@ where
astarte_handler,
}
}

/// Create a [tower] layer to intercept the Node ID in the gRPC requests
pub fn make_interceptor_layer(&self) -> NodeIdInterceptorLayer {
NodeIdInterceptorLayer::new(Arc::clone(&self.nodes))
}
}

#[derive(Error, Debug)]
pub enum InterceptorError {
/// Node ID not specified in the gRPC request
#[error("Node ID not specified in the gRPC request")]
AbsentId,
/// Invalid string value for Node ID
#[error("Invalid string value for Node ID, {0:?}")]
IdToStr(#[from] hyper::header::ToStrError),
/// Invalid Uuid value for Node ID
#[error("Invalid Uuid value for Node ID, {0:?}")]
IdToUuid(#[from] uuid::Error),
/// The Node ID does not exist
#[error("The Node ID does not exist")]
IdNotFound,
#[error("Inner")]
Error(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
}

#[derive(Clone, Default)]
pub struct NodeIdInterceptorLayer {
msg_hub_nodes: Arc<RwLock<HashMap<Uuid, AstarteNode>>>,
}

impl NodeIdInterceptorLayer {
fn new(msg_hub_nodes: Arc<RwLock<HashMap<Uuid, AstarteNode>>>) -> NodeIdInterceptorLayer {
Self { msg_hub_nodes }
}
}

impl<S> Layer<S> for NodeIdInterceptorLayer {
type Service = NodeIdInterceptor<S>;

fn layer(&self, service: S) -> Self::Service {
NodeIdInterceptor {
inner: service,
msg_hub_nodes: self.msg_hub_nodes.clone(),
}
}
}

#[derive(Clone)]
pub struct NodeIdInterceptor<S> {
inner: S,
msg_hub_nodes: Arc<RwLock<HashMap<Uuid, AstarteNode>>>,
}

impl<S> NodeIdInterceptor<S> {
pub async fn handle_req(
mut self,
req: hyper::Request<Body>,
) -> Result<S::Response, InterceptorError>
where
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<InterceptorError>,
{
static ATTACH_PATH: Lazy<http::uri::PathAndQuery> = Lazy::new(|| {
http::uri::PathAndQuery::from_static("/astarteplatform.msghub.MessageHub/Attach")
});

// check that the Node ID is present inside the metadata on all gRPC requests apart from the attach ones
if req.uri().path_and_query() != Some(&ATTACH_PATH) {
let node_id_str = req
.headers()
.get("node_id")
.ok_or(InterceptorError::AbsentId)?
.to_str()?;
let node_uuid = Uuid::parse_str(node_id_str)?;

if !self.msg_hub_nodes.read().await.contains_key(&node_uuid) {
return Err(InterceptorError::IdNotFound);
}
}

self.inner.call(req).await.map_err(S::Error::into)
}
}

type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;

impl<S> Service<hyper::Request<Body>> for NodeIdInterceptor<S>
where
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<InterceptorError>,
{
type Response = S::Response;
type Error = InterceptorError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(S::Error::into)
}

fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
let clone = self.clone();
let service = std::mem::replace(self, clone);

service.handle_req(req).boxed()
}
}

#[tonic::async_trait]
Expand Down Expand Up @@ -184,7 +303,7 @@ impl<T: Clone + AstarteRunner + AstartePublisher + AstarteSubscriber + 'static>
/// payload: Some(Payload::AstarteData(100.into()))
/// };
///
/// let _ = message_hub_client.send(astarte_message).await;
/// let _ = message_hub_client.send(astarte_message).await;
///
/// Ok(())
///
Expand Down
2 changes: 1 addition & 1 deletion src/data/mock_astarte_sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

//! Mocking of the Astarte Device Sdk.
use astarte_device_sdk::{types::AstarteType, AstarteAggregate, Error, Interface};
use astarte_device_sdk::{error::Error, types::AstarteType, AstarteAggregate, Interface};
use async_trait::async_trait;
use mockall::{automock, mock};

Expand Down
10 changes: 10 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async fn main() -> Result<(), AstarteMessageHubError> {
// Run the protobuf server
let addrs = (Ipv6Addr::LOCALHOST, options.grpc_socket_port).into();
tonic::transport::Server::builder()
.layer(message_hub.make_interceptor_layer())
.add_service(MessageHubServer::new(message_hub))
.serve(addrs)
.await?;
Expand Down Expand Up @@ -104,10 +105,19 @@ async fn initialize_astarte_device_sdk(
return Err(AstarteMessageHubError::MissingConfig("interface directory"));
};

let int_dir = int_dir
.as_os_str()
.to_str()
.ok_or(AstarteMessageHubError::MissingConfig(
"invalid interface directory",
))?;

// create a device instance
let (device, rx_events) = DeviceBuilder::new()
.interface_directory(int_dir)?
.store(MemoryStore::new())
// the connect method will internally create a MessageHubClient already configured with an NodeId Interceptor
// layer, necessary to insert the Node ID into the gRPC metadata
.connect(mqtt_config)
.await?
.build();
Expand Down
Empty file added src/types.rs
Empty file.

0 comments on commit 09584ea

Please sign in to comment.