Skip to content

Commit

Permalink
refactor(rust): Groundwork for allowing multi-output nodes in the new…
Browse files Browse the repository at this point in the history
… streaming engine (#20550)
  • Loading branch information
orlp authored Jan 4, 2025
1 parent 58d69d6 commit 7fddd84
Show file tree
Hide file tree
Showing 8 changed files with 355 additions and 269 deletions.
12 changes: 12 additions & 0 deletions crates/polars-stream/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,18 @@ pub fn execute_graph(
let num_pipelines = POOL.current_num_threads();
async_executor::set_num_threads(num_pipelines);

// Ensure everything is properly connected.
for (node_key, node) in &graph.nodes {
for (i, input) in node.inputs.iter().enumerate() {
assert!(graph.pipes[*input].receiver == node_key);
assert!(graph.pipes[*input].recv_port == i);
}
for (i, output) in node.outputs.iter().enumerate() {
assert!(graph.pipes[*output].sender == node_key);
assert!(graph.pipes[*output].send_port == i);
}
}

for node in graph.nodes.values_mut() {
node.compute.initialize(num_pipelines);
}
Expand Down
19 changes: 12 additions & 7 deletions crates/polars-stream/src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use polars_error::PolarsResult;
use slotmap::{SecondaryMap, SlotMap};
use slotmap::{Key, SecondaryMap, SlotMap};

use crate::nodes::ComputeNode;

Expand Down Expand Up @@ -32,7 +32,7 @@ impl Graph {
pub fn add_node<N: ComputeNode + 'static>(
&mut self,
node: N,
inputs: impl IntoIterator<Item = GraphNodeKey>,
inputs: impl IntoIterator<Item = (GraphNodeKey, usize)>,
) -> GraphNodeKey {
// Add the GraphNode.
let node_key = self.nodes.insert(GraphNode {
Expand All @@ -42,8 +42,7 @@ impl Graph {
});

// Create and add pipes that connect input to output.
for (recv_port, sender) in inputs.into_iter().enumerate() {
let send_port = self.nodes[sender].outputs.len();
for (recv_port, (sender, send_port)) in inputs.into_iter().enumerate() {
let pipe = LogicalPipe {
sender,
send_port,
Expand All @@ -58,7 +57,13 @@ impl Graph {

// And connect input to output.
self.nodes[node_key].inputs.push(pipe_key);
self.nodes[sender].outputs.push(pipe_key);
if self.nodes[sender].outputs.len() <= send_port {
self.nodes[sender]
.outputs
.resize(send_port + 1, LogicalPipeKey::null());
}
assert!(self.nodes[sender].outputs[send_port].is_null());
self.nodes[sender].outputs[send_port] = pipe_key;
}

node_key
Expand Down Expand Up @@ -142,14 +147,14 @@ pub struct LogicalPipe {
pub sender: GraphNodeKey,
// Output location:
// graph[x].output[i].send_port == i
send_port: usize,
pub send_port: usize,
pub send_state: PortState,

// Node that we receive data from.
pub receiver: GraphNodeKey,
// Input location:
// graph[x].inputs[i].recv_port == i
recv_port: usize,
pub recv_port: usize,
pub recv_state: PortState,
}

Expand Down
4 changes: 2 additions & 2 deletions crates/polars-stream/src/physical_plan/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ fn visualize_plan_rec(
label
));
for input in inputs {
visualize_plan_rec(*input, phys_sm, expr_arena, visited, out);
visualize_plan_rec(input.node, phys_sm, expr_arena, visited, out);
out.push(format!(
"{} -> {};",
input.data().as_ffi(),
input.node.data().as_ffi(),
node_key.data().as_ffi()
));
}
Expand Down
Loading

0 comments on commit 7fddd84

Please sign in to comment.