Skip to content

Commit

Permalink
added error reporting on publish, update workflows (#139)
Browse files Browse the repository at this point in the history
* added error reporting on publish, update workflows

* rm println, fix lints

* som pub edits

* Added `dial` and `network_info`

* update workflows, fix lintings
  • Loading branch information
erhant authored Nov 6, 2024
1 parent 02d4dc0 commit 86b2877
Show file tree
Hide file tree
Showing 12 changed files with 436 additions and 597 deletions.
854 changes: 311 additions & 543 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default-members = ["compute"]

[workspace.package]
edition = "2021"
version = "0.2.18"
version = "0.2.19"
license = "Apache-2.0"
readme = "README.md"

Expand Down
18 changes: 14 additions & 4 deletions compute/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::utils::{address_in_use, crypto::to_address};
use crate::utils::{
address_in_use,
crypto::{secret_to_keypair, to_address},
};
use dkn_p2p::libp2p::Multiaddr;
use dkn_workflows::DriaWorkflowsConfig;
use eyre::{eyre, Result};
Expand Down Expand Up @@ -71,14 +74,21 @@ impl DriaComputeNodeConfig {
panic!("Please provide an admin public key.");
}
};

let address = to_address(&public_key);
log::info!("Node Address: 0x{}", hex::encode(address));

// to this here to log the peer id at start
log::info!(
"Node PeerID: {}",
secret_to_keypair(&secret_key).public().to_peer_id()
);

log::info!(
"Admin Public Key: 0x{}",
hex::encode(admin_public_key.serialize_compressed())
);

let address = to_address(&public_key);
log::info!("Node Address: 0x{}", hex::encode(address));

let workflows =
DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default());
#[cfg(not(test))]
Expand Down
56 changes: 39 additions & 17 deletions compute/src/handlers/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ impl ComputeHandler for WorkflowHandler {
node: &mut DriaComputeNode,
message: DKNMessage,
) -> Result<MessageAcceptance> {
let config = &node.config;
let task = message
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
.wrap_err("Could not parse workflow task")?;
Expand All @@ -55,7 +54,7 @@ impl ComputeHandler for WorkflowHandler {
}

// check task inclusion via the bloom filter
if !task.filter.contains(&config.address)? {
if !task.filter.contains(&node.config.address)? {
log::info!(
"Task {} does not include this node within the filter.",
task.task_id
Expand All @@ -66,16 +65,19 @@ impl ComputeHandler for WorkflowHandler {
}

// read model / provider from the task
let (model_provider, model) = config.workflows.get_any_matching_model(task.input.model)?;
let (model_provider, model) = node
.config
.workflows
.get_any_matching_model(task.input.model)?;
let model_name = model.to_string(); // get model name, we will pass it in payload
log::info!("Using model {} for task {}", model_name, task.task_id);

// prepare workflow executor
let executor = if model_provider == ModelProvider::Ollama {
Executor::new_at(
model,
&config.workflows.ollama.host,
config.workflows.ollama.port,
&node.config.workflows.ollama.host,
node.config.workflows.ollama.port,
)
} else {
Executor::new(model)
Expand All @@ -93,13 +95,14 @@ impl ComputeHandler for WorkflowHandler {
log::info!("Received cancellation, quitting all tasks.");
return Ok(MessageAcceptance::Accept);
},
exec_result_inner = executor.execute(entry.as_ref(), task.input.workflow, &mut memory) => {
exec_result_inner = executor.execute(entry.as_ref(), &task.input.workflow, &mut memory) => {
exec_result = exec_result_inner.map_err(|e| eyre!("Execution error: {}", e.to_string()));
}
}

match exec_result {
let (publish_result, acceptance) = match exec_result {
Ok(result) => {
log::warn!("Task {} result:", result);
// obtain public key from the payload
let task_public_key_bytes =
hex::decode(&task.public_key).wrap_err("Could not decode public key")?;
Expand All @@ -110,36 +113,55 @@ impl ComputeHandler for WorkflowHandler {
result,
&task.task_id,
&task_public_key,
&config.secret_key,
&node.config.secret_key,
model_name,
)?;
let payload_str = serde_json::to_string(&payload)
.wrap_err("Could not serialize response payload")?;

// publish the result
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
node.publish(message)?;

// accept so that if there are others included in filter they can do the task
Ok(MessageAcceptance::Accept)
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
(node.publish(message), MessageAcceptance::Accept)
}
Err(err) => {
// use pretty display string for error logging with causes
let err_string = format!("{:#}", err);
log::error!("Task {} failed: {}", task.task_id, err_string);

// prepare error payload
let error_payload = TaskErrorPayload::new(task.task_id, err_string, model_name);
let error_payload =
TaskErrorPayload::new(task.task_id.clone(), err_string, model_name);
let error_payload_str = serde_json::to_string(&error_payload)
.wrap_err("Could not serialize error payload")?;

// publish the error result for diagnostics
let message = DKNMessage::new(error_payload_str, Self::RESPONSE_TOPIC);
node.publish(message)?;

// ignore just in case, workflow may be bugged
Ok(MessageAcceptance::Ignore)
let message = DKNMessage::new_signed(
error_payload_str,
Self::RESPONSE_TOPIC,
&node.config.secret_key,
);
(node.publish(message), MessageAcceptance::Ignore)
}
};

// if for some reason we couldnt publish the result, publish the error itself so that RPC doesnt hang
if let Err(publish_err) = publish_result {
let err_msg = format!("Could not publish result: {:?}", publish_err);
log::error!("{}", err_msg);
let payload = serde_json::json!({
"taskId": task.task_id,
"error": err_msg
});
let message = DKNMessage::new_signed(
payload.to_string(),
Self::RESPONSE_TOPIC,
&node.config.secret_key,
);
node.publish(message)?;
}

Ok(acceptance)
}
}
22 changes: 19 additions & 3 deletions compute/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,24 @@ impl DriaComputeNode {
);

// create p2p client
let p2p = DriaP2PClient::new(
let mut p2p = DriaP2PClient::new(
keypair,
config.p2p_listen_addr.clone(),
&available_nodes.bootstrap_nodes,
&available_nodes.relay_nodes,
P2P_VERSION,
)?;

// dial rpc nodes
if available_nodes.rpc_addrs.is_empty() {
log::warn!("No RPC nodes found to be dialled!");
} else {
for rpc_addr in &available_nodes.rpc_addrs {
log::info!("Dialing RPC node: {}", rpc_addr);
p2p.dial(rpc_addr.clone())?;
}
}

Ok(DriaComputeNode {
p2p,
config,
Expand Down Expand Up @@ -136,6 +146,7 @@ impl DriaComputeNode {
event = self.p2p.process_events() => {
// refresh admin rpc peer ids
if self.available_nodes_last_refreshed.elapsed() > Duration::from_secs(RPC_PEER_ID_REFRESH_INTERVAL_SECS) {
log::info!("Refreshing available nodes.");
self.available_nodes = AvailableNodes::get_available_nodes().await.unwrap_or_default().join(self.available_nodes.clone()).sort_dedup();
self.available_nodes_last_refreshed = tokio::time::Instant::now();
}
Expand All @@ -156,12 +167,17 @@ impl DriaComputeNode {
}
};

// log::info!(
// "Received {} message ({})\nFrom: {}\nSource: {}",
// topic_str,
// message_id,
// peer_id,
// );
log::info!(
"Received {} message ({})\nFrom: {}\nSource: {}",
"Received {} message ({}) from {}",
topic_str,
message_id,
peer_id,
source_peer_id
);

// ensure that message is from the static RPCs
Expand Down
36 changes: 25 additions & 11 deletions compute/src/utils/available_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@ const STATIC_RPC_PEER_IDS: [&str; 0] = [];
/// API URL for refreshing the Admin RPC PeerIDs from Dria server.
const RPC_PEER_ID_REFRESH_API_URL: &str = "https://dkn.dria.co/available-nodes";

#[derive(serde::Deserialize, Debug)]
pub struct AvailableNodesApiResponse {
pub bootstraps: Vec<String>,
pub relays: Vec<String>,
pub rpcs: Vec<String>,
}

/// Available nodes within the hybrid P2P network.
///
/// - Bootstrap: used for Kademlia DHT bootstrap.
/// - Relay: used for DCutR relay protocol.
/// - RPC: used for RPC nodes for task & ping messages.
///
/// Note that while bootstrap & relay nodes are `Multiaddr`, RPC nodes are `PeerId` because we communicate
/// with them via GossipSub only.
#[derive(Debug, Default, Clone)]
pub struct AvailableNodes {
pub bootstrap_nodes: Vec<Multiaddr>,
pub relay_nodes: Vec<Multiaddr>,
pub rpc_nodes: Vec<PeerId>,
pub rpc_addrs: Vec<Multiaddr>,
}

impl AvailableNodes {
Expand Down Expand Up @@ -66,6 +68,7 @@ impl AvailableNodes {
bootstrap_nodes: parse_vec(bootstrap_nodes),
relay_nodes: parse_vec(relay_nodes),
rpc_nodes: vec![],
rpc_addrs: vec![],
}
}

Expand All @@ -75,6 +78,7 @@ impl AvailableNodes {
bootstrap_nodes: parse_vec(STATIC_BOOTSTRAP_NODES.to_vec()),
relay_nodes: parse_vec(STATIC_RELAY_NODES.to_vec()),
rpc_nodes: parse_vec(STATIC_RPC_PEER_IDS.to_vec()),
rpc_addrs: vec![],
}
}

Expand All @@ -83,7 +87,7 @@ impl AvailableNodes {
self.bootstrap_nodes.extend(other.bootstrap_nodes);
self.relay_nodes.extend(other.relay_nodes);
self.rpc_nodes.extend(other.rpc_nodes);

self.rpc_addrs.extend(other.rpc_addrs);
self
}

Expand All @@ -98,18 +102,31 @@ impl AvailableNodes {
self.rpc_nodes.sort_unstable();
self.rpc_nodes.dedup();

self.rpc_addrs.sort_unstable();
self.rpc_addrs.dedup();

self
}

/// Refreshes the available nodes for Bootstrap, Relay and RPC nodes.
pub async fn get_available_nodes() -> Result<Self> {
#[derive(serde::Deserialize, Debug)]
struct AvailableNodesApiResponse {
pub bootstraps: Vec<String>,
pub relays: Vec<String>,
pub rpcs: Vec<String>,
#[serde(rename = "rpcAddrs")]
pub rpc_addrs: Vec<String>,
}

let response = reqwest::get(RPC_PEER_ID_REFRESH_API_URL).await?;
let response_body = response.json::<AvailableNodesApiResponse>().await?;

Ok(Self {
bootstrap_nodes: parse_vec(response_body.bootstraps),
relay_nodes: parse_vec(response_body.relays),
rpc_nodes: parse_vec(response_body.rpcs),
rpc_addrs: parse_vec(response_body.rpc_addrs),
})
}
}
Expand Down Expand Up @@ -137,9 +154,6 @@ mod tests {
#[tokio::test]
#[ignore = "run this manually"]
async fn test_get_available_nodes() {
std::env::set_var("RUST_LOG", "info");
let _ = env_logger::try_init();

let available_nodes = AvailableNodes::get_available_nodes().await.unwrap();
println!("{:#?}", available_nodes);
}
Expand Down
12 changes: 6 additions & 6 deletions p2p/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use libp2p::{autonat, dcutr, gossipsub, identify, kad, relay};

#[derive(libp2p::swarm::NetworkBehaviour)]
pub struct DriaBehaviour {
pub(crate) relay: relay::client::Behaviour,
pub(crate) gossipsub: gossipsub::Behaviour,
pub(crate) kademlia: kad::Behaviour<MemoryStore>,
pub(crate) identify: identify::Behaviour,
pub(crate) autonat: autonat::Behaviour,
pub(crate) dcutr: dcutr::Behaviour,
pub relay: relay::client::Behaviour,
pub gossipsub: gossipsub::Behaviour,
pub kademlia: kad::Behaviour<MemoryStore>,
pub identify: identify::Behaviour,
pub autonat: autonat::Behaviour,
pub dcutr: dcutr::Behaviour,
}

impl DriaBehaviour {
Expand Down
18 changes: 14 additions & 4 deletions p2p/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use super::*;
use eyre::Result;
use eyre::{Context, Result};
use libp2p::futures::StreamExt;
use libp2p::gossipsub::{
Message, MessageAcceptance, MessageId, PublishError, SubscriptionError, TopicHash,
};
use libp2p::kad::{GetClosestPeersError, GetClosestPeersOk, QueryResult};
use libp2p::{
autonat, gossipsub, identify, kad, multiaddr::Protocol, noise, swarm::SwarmEvent, tcp, yamux,
};
use libp2p::swarm::{dial_opts::DialOpts, NetworkInfo, SwarmEvent};
use libp2p::{autonat, gossipsub, identify, kad, multiaddr::Protocol, noise, tcp, yamux};
use libp2p::{Multiaddr, PeerId, StreamProtocol, Swarm, SwarmBuilder};
use libp2p_identity::Keypair;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -137,6 +136,12 @@ impl DriaP2PClient {
})
}

/// Returns the network information, such as the number of
/// incoming and outgoing connections.
pub fn network_info(&self) -> NetworkInfo {
self.swarm.network_info()
}

/// Subscribe to a topic.
pub fn subscribe(&mut self, topic_name: &str) -> Result<bool, SubscriptionError> {
log::debug!("Subscribing to {}", topic_name);
Expand Down Expand Up @@ -206,6 +211,11 @@ impl DriaP2PClient {
self.swarm.behaviour().gossipsub.all_peers().collect()
}

/// Dials a given peer.
pub fn dial(&mut self, peer_id: impl Into<DialOpts>) -> Result<()> {
self.swarm.dial(peer_id).wrap_err("could not dial")
}

/// Listens to the Swarm for incoming messages.
///
/// This method should be called in a loop to keep the client running.
Expand Down
1 change: 0 additions & 1 deletion workflows/src/apis/serper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ impl SerperConfig {
log::debug!("Serper API key not found, skipping Serper check");
return Ok(());
};
println!("API KEY: {}", api_key);
log::info!("Serper API key found, checking Serper service");

// make a dummy request
Expand Down
Loading

0 comments on commit 86b2877

Please sign in to comment.