From 516a10690ec4abdd52ecba275a6ddc7a5db72f17 Mon Sep 17 00:00:00 2001 From: Eli Rosenthal Date: Thu, 30 May 2024 13:01:39 -0700 Subject: [PATCH] Simplify control flow for CFGs generated by eggcc. (#610) Add a new branch simplification pass to reduce overhead introduced by RVSDGs --- Cargo.lock | 94 +-- dag_in_context/Cargo.toml | 2 +- dag_in_context/rust-toolchain | 2 +- dag_in_context/src/typechecker.rs | 3 +- rust-toolchain | 2 +- src/rvsdg/mod.rs | 1 + src/rvsdg/optimize_direct_jumps.rs | 11 +- src/rvsdg/rvsdg2svg.rs | 23 +- src/rvsdg/simplify_branches.rs | 652 ++++++++++++++++++ src/rvsdg/to_dag.rs | 2 +- .../snapshots/files__fib_shape-optimize.snap | 6 +- .../files__flatten_loop-optimize.snap | 19 +- .../files__implicit-return-optimize.snap | 14 +- tests/snapshots/files__loop_if-optimize.snap | 23 +- .../files__range_check-optimize.snap | 10 +- .../files__range_splitting-optimize.snap | 10 +- .../snapshots/files__small-fib-optimize.snap | 6 +- tests/snapshots/files__sqrt-optimize.snap | 13 +- 18 files changed, 788 insertions(+), 105 deletions(-) create mode 100644 src/rvsdg/simplify_branches.rs diff --git a/Cargo.lock b/Cargo.lock index 68ff4942f..800bc9486 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,9 +91,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "arbitrary" @@ -130,7 +130,7 @@ dependencies = [ "argh_shared", "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -196,7 +196,7 @@ dependencies = [ [[package]] name = "bril-rs" version = "0.1.0" -source = "git+https://github.com/uwplse/bril?rev=78881c45aa53231915f333d1d6dcc26cedc63b57#78881c45aa53231915f333d1d6dcc26cedc63b57" +source = "git+https://github.com/uwplse/bril?rev=e2be3f5#e2be3f5d7e160f02b7aed0ef2bcc3e13ae722d2b" dependencies = [ "serde", "serde_json", @@ -290,9 +290,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" [[package]] name = "cfg-if" @@ -331,7 +331,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -504,9 +504,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] @@ -531,7 +531,7 @@ dependencies = [ name = "dag_in_context" version = "0.1.0" dependencies = [ - "bril-rs 0.1.0 (git+https://github.com/uwplse/bril?rev=78881c45aa53231915f333d1d6dcc26cedc63b57)", + "bril-rs 0.1.0 (git+https://github.com/uwplse/bril?rev=e2be3f5)", "dot-structures", "egglog", "egraph-serialize", @@ -633,7 +633,7 @@ dependencies = [ "rs2bril", "serde_json", "smallvec", - "syn 2.0.64", + "syn 2.0.66", "tempfile", "thiserror", ] @@ -718,7 +718,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -817,9 +817,9 @@ source = "git+https://github.com/oflatt/symbolic-expressions?rev=655b6a4c06b4b3d [[package]] name = "getrandom" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", @@ -958,7 +958,7 @@ source = "git+https://github.com/TheDan64/inkwell.git?rev=6c0fb56b3554e939f9ca61 dependencies = [ "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -976,9 +976,9 @@ dependencies = [ [[package]] name = "instant" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" dependencies = [ "cfg-if", ] @@ -1076,15 +1076,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "libmimalloc-sys" -version = "0.1.37" +version = "0.1.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81eb4061c0582dedea1cbc7aff2240300dd6982e0239d1c99e65c1dbf4a30ba7" +checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" dependencies = [ "cc", "libc", @@ -1119,9 +1119,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "llvm-sys" @@ -1176,9 +1176,9 @@ checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mimalloc" -version = "0.1.41" +version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f41a2280ded0da56c8cf898babb86e8f10651a34adcfff190ae9a1159c6908d" +checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" dependencies = [ "libmimalloc-sys", ] @@ -1284,9 +1284,9 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -1336,7 +1336,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -1395,9 +1395,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "proc-macro2" -version = "1.0.82" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" dependencies = [ "unicode-ident", ] @@ -1540,7 +1540,7 @@ dependencies = [ "bril-rs 0.1.0 (git+https://github.com/uwplse/bril?rev=e2be3f5d7e160f02b7aed0ef2bcc3e13ae722d2b)", "clap", "proc-macro2", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -1597,22 +1597,22 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.202" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.202" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -1720,7 +1720,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -1746,9 +1746,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.64" +version = "2.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ad3dee41f36859875573074334c200d1add8e4a87bb37113ebd31d926b7b11f" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" dependencies = [ "proc-macro2", "quote", @@ -1795,22 +1795,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] [[package]] @@ -1866,9 +1866,9 @@ dependencies = [ [[package]] name = "triomphe" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" +checksum = "1b2cb4fbb9995eeb36ac86fadf24031ccd58f99d6b4b2d7b911db70bddb80d90" [[package]] name = "typenum" @@ -2054,5 +2054,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.64", + "syn 2.0.66", ] diff --git a/dag_in_context/Cargo.toml b/dag_in_context/Cargo.toml index 5e2ba1741..da7a5d891 100644 --- a/dag_in_context/Cargo.toml +++ b/dag_in_context/Cargo.toml @@ -12,7 +12,7 @@ strum_macros = "0.25" main_error = "0.1.2" thiserror = "1.0" egraph-serialize = "0.1.0" -bril-rs = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" } +bril-rs = { git = "https://github.com/uwplse/bril", rev = "e2be3f5" } indexmap = "2.0.0" rustc-hash = "1.1.0" ordered-float = "3" diff --git a/dag_in_context/rust-toolchain b/dag_in_context/rust-toolchain index 283edc6d7..54227249d 100644 --- a/dag_in_context/rust-toolchain +++ b/dag_in_context/rust-toolchain @@ -1 +1 @@ -1.74.0 \ No newline at end of file +1.78.0 diff --git a/dag_in_context/src/typechecker.rs b/dag_in_context/src/typechecker.rs index 48d2d9920..5d933a5db 100644 --- a/dag_in_context/src/typechecker.rs +++ b/dag_in_context/src/typechecker.rs @@ -62,8 +62,8 @@ impl Expr { ); new_expr } - /// Adds argument types to the expression. + #[allow(dead_code)] pub(crate) fn add_arg_type(self: RcExpr, input_ty: Type) -> RcExpr { // we need a dummy program, since there are no calls in self let prog = program!(function("dummy", tuplet!(), tuplet!(), empty()),); @@ -73,6 +73,7 @@ impl Expr { new_expr } + #[allow(dead_code)] pub(crate) fn func_with_arg_types(self: RcExpr) -> RcExpr { match self.as_ref() { Expr::Function(name, in_ty, out_ty, body) => RcExpr::new(Expr::Function( diff --git a/rust-toolchain b/rust-toolchain index 283edc6d7..54227249d 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.74.0 \ No newline at end of file +1.78.0 diff --git a/src/rvsdg/mod.rs b/src/rvsdg/mod.rs index bf91a8a17..5ffb1914a 100644 --- a/src/rvsdg/mod.rs +++ b/src/rvsdg/mod.rs @@ -34,6 +34,7 @@ pub(crate) mod live_variables; pub(crate) mod optimize_direct_jumps; pub(crate) mod restructure; pub(crate) mod rvsdg2svg; +pub(crate) mod simplify_branches; pub(crate) mod to_cfg; mod to_dag; diff --git a/src/rvsdg/optimize_direct_jumps.rs b/src/rvsdg/optimize_direct_jumps.rs index dab0c1a28..016f865f9 100644 --- a/src/rvsdg/optimize_direct_jumps.rs +++ b/src/rvsdg/optimize_direct_jumps.rs @@ -168,7 +168,16 @@ impl SimpleCfgFunction { impl SimpleCfgProgram { pub fn optimize_jumps(&self) -> Self { SimpleCfgProgram { - functions: self.functions.iter().map(|f| f.optimize_jumps()).collect(), + functions: self + .functions + .iter() + .map(|f| { + // NB: We could avoid this copy by having `optimize_jumps` take `self` by value. + let mut res = f.optimize_jumps(); + res.simplify_branches(); + res + }) + .collect(), } } } diff --git a/src/rvsdg/rvsdg2svg.rs b/src/rvsdg/rvsdg2svg.rs index b044ed04b..8c1559725 100644 --- a/src/rvsdg/rvsdg2svg.rs +++ b/src/rvsdg/rvsdg2svg.rs @@ -1,4 +1,5 @@ use std::collections::{BTreeMap, BTreeSet}; +use std::fmt; use std::iter::once; use bril_rs::ConstOps; @@ -70,22 +71,18 @@ impl Xml { } } -impl ToString for Xml { - fn to_string(self: &Xml) -> String { - use std::fmt::Write; - let mut out = String::new(); - - write!(out, "<{}", self.tag).unwrap(); +impl fmt::Display for Xml { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "<{}", self.tag)?; for (key, value) in &self.attributes { - write!(out, "\n\t{key}=\"{value}\"").unwrap(); + write!(f, "\n\t{key}=\"{value}\"")?; } - writeln!(out, ">").unwrap(); + writeln!(f, ">")?; for line in self.body.lines() { - writeln!(out, "\t{line}").unwrap(); + writeln!(f, "\t{line}")?; } - writeln!(out, "", self.tag).unwrap(); - - out + writeln!(f, "", self.tag)?; + Ok(()) } } @@ -655,7 +652,7 @@ impl RvsdgProgram { let (size, mut xml) = function.to_region().to_xml(false); // assert that it doesn't have a transform yet - assert!(xml.attributes.get("transform").is_none()); + assert!(!xml.attributes.contains_key("transform")); xml.attributes .insert("transform".to_owned(), format!("translate(0, {})", height)); xmls.push(xml); diff --git a/src/rvsdg/simplify_branches.rs b/src/rvsdg/simplify_branches.rs new file mode 100644 index 000000000..6822bed7f --- /dev/null +++ b/src/rvsdg/simplify_branches.rs @@ -0,0 +1,652 @@ +//! A pass over a CFG returned from the RVSDG=>CFG [conversion module](crate::rvsdg::to_cfg) +//! to simplify branch structures. +//! +//! # Introduction +//! RVSDGs are more structured than arbitrary CFGs. The only control strutures +//! that RVSDGs support directly are ifs (with an else) and tail-controlled +//! loops. This means that any other control flow constructs, from `for` loops +//! all the way to `goto`s need to be simulated using auxiliary predicate +//! variables introduced during translation time. +//! +//! The resulting representation is great for declarative rewrites, but can +//! generate slower code when directly interpreted as a CFG, e.g. to break out +//! of multiple layers of nested loop, CFGs only require a single jump but +//! RVSDGs require a branch for every layer of nesting. +//! +//! The algorithm in this module aims to recover more natural, efficient +//! control-flow structure from the raw CFG generated by an RVSDG. It is +//! inspired by the PCFR algorithm described in "Perfect Reconstructability of +//! Control Flow from Demand Dependence Graphs" by Bahman, Reissmann, Jahre, and +//! Meyer, but it has a different structure: +//! +//! * PCFR operates on the RVSDG directly, while this algorithm operates on the +//! resulting CFG. This is pragmatically useful for eggcc, which already has +//! fairly involved RVSDG=>CFG conversion code. +//! * PCFR expects an RVSDG in _predicate continuation form_, where predicates +//! are introduced immediately before they are used. eggcc almost certainly +//! does not preserve this property, and we want to avoid duplicating or +//! splitting RVSDG nodes to reintroduce it. The algorithm in this module is +//! robust to some predicates being used more than once, sometiems across +//! branches. +//! * The algorithm in this module has not been optimized for efficiency and as +//! a result is likely slower than a good implementation of PCFR. This doesn't +//! seem like an inherent issue and the performance of the two should be +//! similar after some optimization. +//! * The paper from Bahman et. al. also sketches a "ShortCircuitCFG" algorithm +//! that is similar to the algorithm here, but makes some simplifying +//! assumptions, potentially based again on predicate continuation form. +//! +//! # Algorithm Overview +//! The algorithm code is fairly heavily commented. It relies on computing the +//! fixpoint of a monotone dataflow analysis tracking the value of boolean +//! identifiers at each CFG node. The analysis takes branches into account: +//! successors along the "true" edge of a branch on 'x' know that 'x' is true. +//! With that information in place (along with a few technical details explained +//! in code comments), we apply two kinds of rewrites on the CFG: +//! +//! * For patterns like `X -[e]-> Y -[if a=1]-> Z` where we know that `a=1` in `X` +//! (and `Y` doesn't overwrite `a`), rewrite to `X -[e]-> Z`. +//! * For patterns like `X -[if a=1]-> Y` where we know that `a=1` in `X`, +//! rewrite to `X -[jump]-> Y` and remove all other outgoing edges from `X`. +//! If this is the only incoming branch to `Y`, a future optimize_direct_jumps +//! pass will merge the two blocks entirely. +//! +//! The boolean value analysis should converge quickly given the structure of +//! the CFGs we generate, but the current implementation involves lots of +//! copying of data: If the CFG were in SSI form (SSA + variable splits on +//! branches), I believe that we could build a more efficient analysis by +//! looking at nodes where variables are assigned (or branch targets) and +//! relying on dominance information to infer whether a boolean variable has a +//! known value at a node. + +use std::{collections::VecDeque, io::Write, mem}; + +use crate::cfg::{BasicBlock, BlockName, Branch, BranchOp, CondVal, Identifier, SimpleCfgFunction}; +use bril_rs::{Argument, Instruction, Literal, Type, ValueOps}; +use hashbrown::{HashMap, HashSet}; +use indexmap::{IndexMap, IndexSet}; +use petgraph::{ + graph::{EdgeIndex, NodeIndex}, + visit::Dfs, + Direction, +}; + +impl SimpleCfgFunction { + pub(crate) fn simplify_branches(&mut self) { + // Run the whole thing in a fixpoint loop. Question: does doing reverse + // post-order iteration + incrementally maintaining the value analysis + // allow us to converge in a single iteration? + let mut changed = true; + let mut var_counter = 0; + while changed { + // Step 1: compute some information about the CFG. + // * Find "administrative" nodes. + // * Find conditional branches. + // * Start off a Value Analysis for the function. + let branch_meta = self.get_branch_metadata(); + let mut val_analysis = ValueAnalysis::new(self); + // Step 2: split conditional branches and mark the relevant constants as + // known in the later nodes. This lets us simplify the value analysis by + // having empty nodes enncapsulate the information imparted by the + // branch. + for (id, edge, val) in branch_meta + .branches + .iter() + .flat_map(|(id, edges)| edges.iter().map(move |(edge, val)| (id, *edge, val))) + { + let Some(lit) = to_lit(val) else { + continue; + }; + // Count downwards from usize::MAX to avoid collisions with other placeholders + let node_bound = usize::MAX - var_counter; + var_counter += 1; + let (source, target) = self.graph.edge_endpoints(edge).unwrap(); + let weight = self.graph.remove_edge(edge).unwrap(); + let block_name = BlockName::Placeholder(node_bound); + let mid = self.graph.add_node(BasicBlock::empty(block_name)); + self.graph.add_edge(source, mid, weight); + // NB: We rely on the optimize_direct_jumps pass to collapse this + // back down. + self.graph.add_edge( + mid, + target, + Branch { + op: BranchOp::Jmp, + pos: None, + }, + ); + val_analysis.add_assignment(mid, id.clone(), ValueInfo::Known(lit)); + } + // Step 3: Compute the fixpoint of the value analysis. + val_analysis.compute_fixpoint(self); + // Step 4: Rewrite branches + changed = self.rewrite_branches(&branch_meta, &val_analysis); + // Step 5: Remove any nodes no longer reachable from the entry. + self.retain_reachable(); + + // Step 6: Rewrite direct jumps. This will un-split any edges split + // in step 2, and potentially merge nodes where a conditional branch + // was replaced with a jump. + *self = self.optimize_jumps(); + } + } + + fn get_branch_metadata(&self) -> BranchMetadata { + let mut res = BranchMetadata::default(); + for node in self.graph.node_indices() { + let block = &self.graph[node]; + if block.footer.is_empty() && block.instrs.iter().all(is_admin_instr) { + res.admin_nodes.insert(node); + } + for (id, lit) in block.instrs.iter().filter_map(constants_assigned) { + res.constants_known.add_constant(node, id, lit); + } + let mut walker = self + .graph + .neighbors_directed(node, Direction::Outgoing) + .detach(); + while let Some((edge, _)) = walker.next(&self.graph) { + if let BranchOp::Cond { arg, val, .. } = &self.graph[edge].op { + res.branches + .entry(arg.clone()) + .or_default() + .insert(edge, *val); + } + } + } + res + } + + /// Remove any nodes that are no longer reachable from the entry node. + fn retain_reachable(&mut self) { + let mut walker = Dfs::new(&self.graph, self.entry); + while walker.next(&self.graph).is_some() {} + let mut to_remove = vec![]; + for node_id in self.graph.node_indices() { + if !walker.discovered.contains(node_id.index()) { + to_remove.push(node_id); + assert_ne!( + node_id, self.exit, + "branch simplification removed the exit node!" + ); + } + } + for node_id in to_remove { + self.graph.remove_node(node_id); + } + } + + /// Simplify control flow using the information gathered in the initial steps. + /// * For each administrative node `n``... + /// * For each outgoing branch [edge e1] with cond val `v` for `id` + /// * Check if `id` was written to in `n`, if it was, then move on + /// _unless_ we know the value of `id`; in which case we can replace the branch with a jump. + /// * Otherwise, check if a predecessor [via edge e2] node has `v` as a + /// known value for `id`. + /// * If so, copy the contents of the admin node to that predecessor, and + /// reroute e2 to the target of e1. + fn rewrite_branches( + &mut self, + branch_meta: &BranchMetadata, + val_analysis: &ValueAnalysis, + ) -> bool { + let mut scratch = Vec::new(); + let mut changed = false; + for admin_node in &branch_meta.admin_nodes { + let mut walker = self + .graph + .neighbors_directed(*admin_node, Direction::Outgoing) + .detach(); + // Don't reroute past the exit node. We want to make sure it stays reachable. + if admin_node == &self.exit { + continue; + } + while let Some((outgoing, succ)) = walker.next(&self.graph) { + let BranchOp::Cond { arg, val, .. } = self.graph[outgoing].op.clone() else { + continue; + }; + let Some(val) = to_lit(&val) else { + continue; + }; + if val_analysis.data[admin_node].kills.contains(&arg) { + if succ != self.exit + && self.graph.neighbors(*admin_node).any(|x| x == self.exit) + { + // Don't remove any outgoing links to the exit node. + break; + } + // We assign to the branched-on argument in the admin + // node. See if we can fold the constant branch here. + let ValueInfo::Known(lit) = val_analysis.data[admin_node].get_output(&arg) + else { + continue; + }; + if lit != val { + continue; + } + // okay, we have found a matching edge. Replace this branch + // with a jump. + let mut walker = self + .graph + .neighbors_directed(*admin_node, Direction::Outgoing) + .detach(); + while let Some((outgoing, _)) = walker.next(&self.graph) { + self.graph.remove_edge(outgoing); + } + self.graph.add_edge( + *admin_node, + succ, + Branch { + op: BranchOp::Jmp, + pos: None, + }, + ); + changed = true; + // Don't run the rest of the inner loop. + break; + } + let mut incoming_walker = self + .graph + .neighbors_directed(*admin_node, Direction::Incoming) + .detach(); + while let Some((incoming, pred)) = incoming_walker.next(&self.graph) { + // We may be able to reroute a branch if the value in + // question is known to equal the branch value in a + // predecessor block. + if !matches!(val_analysis.data[&pred].get_output(&arg), ValueInfo::Known(v) if v == val) + { + continue; + } + + // We only have to worry about `instrs` because we + // checked that the footer was empty when we populated + // admin_nodes. We do this because we more or less don't + // use footers on our way back to bril. + scratch.extend(self.graph[*admin_node].instrs.iter().cloned()); + let (_, target) = self.graph.edge_endpoints(outgoing).unwrap(); + let target_incoming = self + .graph + .neighbors_directed(target, Direction::Incoming) + .count(); + let is_jump = matches!(self.graph[incoming].op, BranchOp::Jmp); + // Now it comes to move the block somewhere: if the + // incoming edge is a jump, then we would run all of the + // instructions in the current block anyway, we can just + // move them up. + if is_jump { + let weight = self.graph.remove_edge(incoming).unwrap(); + self.graph[pred].instrs.append(&mut scratch); + self.graph.add_edge(pred, target, weight); + changed = true; + break; + } else if target_incoming == 1 { + // The next safe case is if we are replacing the targets + // only incoming edge. In that case, we can move the + // data down. + let weight = self.graph.remove_edge(incoming).unwrap(); + let target_block = &mut self.graph[target]; + scratch.append(&mut target_block.instrs); + mem::swap(&mut target_block.instrs, &mut scratch); + self.graph.add_edge(pred, target, weight); + changed = true; + break; + } else { + scratch.clear(); + // Otherwise we may need some sort of compatibility check to + // merge the block somewhere. Add the edge back for now: + } + } + } + } + changed + } +} + +#[derive(Default, Debug)] +struct BranchMetadata { + /// Nodes that only contain administrative instructions. + admin_nodes: IndexSet, + /// Information about known constant values at particular nodes. + constants_known: ConstantInfo, + /// Relevant values used as branches. + branches: IndexMap>, +} + +/// Constants with a known value as of a given node. +/// +/// For now, the constants are always booleans, but we keep arbitrary +/// Literals around to make it easier to handle multi-way branches later. +#[derive(Default, Debug)] +struct ConstantInfo { + by_node: IndexMap>, + by_id: IndexMap>, +} + +impl ConstantInfo { + fn add_constant(&mut self, node: NodeIndex, id: Identifier, lit: Literal) { + if self + .by_node + .entry(node) + .or_default() + .insert(id.clone(), lit.clone()) + .is_none() + { + self.by_id.entry(id).or_default().push((node, lit)); + } + } +} + +/// "Administrative Instructions" are ones that will have essentially no runtime +/// cost once they go through instruction selection / register allocation. We +/// use these as a heuristic to find blocks that are safe to merge into their +/// predecessors in exchange for simpler control flow: RVSDG conversion overhead +/// is largely contained in blocks only containing these instructions. +fn is_admin_instr(inst: &Instruction) -> bool { + matches!( + inst, + Instruction::Constant { .. } + | Instruction::Value { + op: ValueOps::Id, + .. + } + ) +} + +fn constants_assigned(inst: &Instruction) -> Option<(Identifier, Literal)> { + if let Instruction::Constant { + dest, + value: value @ Literal::Bool(..), + .. + } = inst + { + Some((dest.into(), value.clone())) + } else { + None + } +} + +fn to_lit(cv: &CondVal) -> Option { + if cv.of == 2 { + Some(if cv.val == 0 { + Literal::Bool(false) + } else { + Literal::Bool(true) + }) + } else { + // Not handling multi-way branches for now. + None + } +} + +/// A basic semilattice describing the state of a value. +#[derive(Clone, Default, Debug)] +enum ValueInfo { + /// Nothing is currently known about the value. + #[default] + Bot, + /// The value is known to hold a concrete value. + Known(Literal), + /// We know that we cannot approximate the value with a single constant. + Top, +} + +impl ValueInfo { + /// Merge two ValueInfos, returning true if `self` changed. + fn merge(&mut self, other: &ValueInfo) -> bool { + match (self, other) { + (ValueInfo::Bot, ValueInfo::Bot) => false, + (slf @ ValueInfo::Bot, x) => { + *slf = x.clone(); + true + } + (ValueInfo::Top, _) => false, + (slf, ValueInfo::Top) => { + *slf = ValueInfo::Top; + true + } + (ValueInfo::Known(l), ValueInfo::Known(r)) if l == r => false, + (slf @ ValueInfo::Known(_), ValueInfo::Known(_)) => { + *slf = ValueInfo::Top; + true + } + (ValueInfo::Known(_), ValueInfo::Bot) => false, + } + } +} + +/// Monotone transforms on ValueInfos. +#[derive(Debug)] +enum Transform { + Id, + Negate, + OverWrite(ValueInfo), +} + +impl Transform { + fn apply(&self, val: &ValueInfo) -> ValueInfo { + match self { + Transform::Id => val.clone(), + Transform::Negate => match val { + ValueInfo::Bot => ValueInfo::Bot, + ValueInfo::Known(Literal::Bool(b)) => ValueInfo::Known(Literal::Bool(!b)), + ValueInfo::Known(..) => ValueInfo::Top, + ValueInfo::Top => ValueInfo::Top, + }, + Transform::OverWrite(info) => info.clone(), + } + } +} + +/// The state of the (boolean) values in a particular basic block. +#[derive(Default, Debug)] +struct ValueState { + /// The (pointwise) join of all of the values in incoming branches. + inherited: IndexMap, + /// The transforms induced by any operations on variables in the block. + /// + /// These are computed during initialization from instructions in a basic + /// block and during step 2 of the main algorithm to add values for the + /// targets of conditional branches. + transforms: VecDeque<( + Identifier, /* dst */ + Identifier, /* src */ + Transform, + )>, + /// The set of variables written to in this basic block. + kills: HashSet, + /// The materialized output of transforms on inherited. + outputs: IndexMap, + /// A variable indicating if `outputs` is stale. + recompute: bool, +} + +impl ValueState { + /// Recompute the outputs for this state, if necessary. + fn maybe_recompute(&mut self) -> bool { + let res = self.recompute; + if self.recompute { + self.outputs.clear(); + for (id, info) in &self.inherited { + self.outputs.insert(id.clone(), info.clone()); + } + for (dst, src, transform) in &self.transforms { + let src_val = self.outputs.get(src).unwrap_or(&ValueInfo::Bot); + let dst_val = transform.apply(src_val); + self.outputs.insert(dst.clone(), dst_val); + self.kills.insert(dst.clone()); + } + + self.recompute = false; + } + res + } + fn outputs(&self) -> impl Iterator { + assert!(!self.recompute); + self.outputs.iter() + } + + fn get_output(&self, id: &Identifier) -> ValueInfo { + assert!(!self.recompute); + self.outputs.get(id).cloned().unwrap_or(ValueInfo::Bot) + } + + /// A special case of `merge_from` to handle self-loops. + fn merge_self(&mut self) { + let mut changed = false; + for (id, out) in self.outputs.iter() { + changed |= self.inherited.entry(id.clone()).or_default().merge(out); + } + if changed { + self.recompute = true; + } + } + + /// Update the given inputs with the contents of `other`. + fn merge_from(&mut self, other: &ValueState) { + let mut changed = false; + for (id, out) in other.outputs() { + changed |= self.inherited.entry(id.clone()).or_default().merge(out); + } + if changed { + self.recompute = true; + } + } + + /// Populate the ValueState with relevant instructions from the given basic + /// block. + fn new(block: &BasicBlock) -> ValueState { + let mut transforms = VecDeque::new(); + for instr in &block.instrs { + match instr { + Instruction::Constant { + dest, + value: lit @ Literal::Bool(..), + .. + } => { + // The `src` identifier is unused in this case. + transforms.push_back(( + Identifier::from(dest.clone()), + Identifier::Num(usize::MAX), + Transform::OverWrite(ValueInfo::Known(lit.clone())), + )); + } + Instruction::Value { + args, + dest, + op, + op_type: Type::Bool, + .. + } => match op { + ValueOps::Id => { + assert_eq!(args.len(), 1); + transforms.push_back(( + Identifier::from(dest.clone()), + args[0].clone().into(), + Transform::Id, + )); + } + ValueOps::Not => { + assert_eq!(args.len(), 1); + transforms.push_back(( + Identifier::from(dest.clone()), + args[0].clone().into(), + Transform::Negate, + )); + } + _ => { + transforms.push_back(( + Identifier::from(dest.clone()), + Identifier::Num(usize::MAX), + Transform::OverWrite(ValueInfo::Top), + )); + } + }, + Instruction::Effect { .. } => {} + Instruction::Constant { .. } => {} + Instruction::Value { .. } => {} + } + } + ValueState { + transforms, + recompute: true, + ..Default::default() + } + } +} + +struct ValueAnalysis { + data: HashMap, +} + +impl ValueAnalysis { + fn new(graph: &SimpleCfgFunction) -> ValueAnalysis { + let mut res = ValueAnalysis { + data: Default::default(), + }; + for node in graph.graph.node_indices() { + res.data.insert(node, ValueState::new(&graph.graph[node])); + } + for Argument { name, arg_type } in &graph.args { + if let Type::Bool = arg_type { + let id = Identifier::from(name.clone()); + res.add_assignment(graph.entry, id, ValueInfo::Top); + } + } + res + } + + /// Prepend a virtual `id` instruction to the analysis for this node. + fn add_assignment(&mut self, node: NodeIndex, dst: Identifier, val: ValueInfo) { + let state = self.data.entry(node).or_default(); + state + .transforms + .push_front((dst, Identifier::Num(usize::MAX), Transform::OverWrite(val))); + state.recompute = true; + } + + /// A simple worklist algorithm for propagating values through the CFG. + fn compute_fixpoint(&mut self, func: &SimpleCfgFunction) { + let mut worklist = IndexSet::::default(); + for node in func.graph.node_indices() { + self.data.entry(node).or_default().maybe_recompute(); + worklist.insert(node); + } + while let Some(node) = worklist.pop() { + let mut cur = mem::take(self.data.get_mut(&node).unwrap()); + for incoming in func.graph.neighbors_directed(node, Direction::Incoming) { + if incoming == node { + cur.merge_self(); + } else { + cur.merge_from(&self.data[&incoming]); + } + } + let changed = cur.maybe_recompute(); + self.data.insert(node, cur); + if changed { + for outgoing in func.graph.neighbors_directed(node, Direction::Outgoing) { + worklist.insert(outgoing); + } + } + } + } + + /// Debugging routine for printing out the state of the analysis. + #[allow(unused)] + fn render(&self, func: &SimpleCfgFunction) -> String { + let mut buf = Vec::::new(); + for (node, state) in &self.data { + let name = &func.graph[*node].name; + writeln!(buf, "{name}.inputs: {{").unwrap(); + for (id, info) in &state.inherited { + writeln!(buf, " {id:?}: {info:?}").unwrap(); + } + writeln!(buf, "}}").unwrap(); + writeln!(buf, "{name}.outputs: {{").unwrap(); + for (id, info) in &state.outputs { + writeln!(buf, " {id:?}: {info:?}").unwrap(); + } + writeln!(buf, "}}").unwrap(); + } + String::from_utf8(buf).unwrap() + } +} diff --git a/src/rvsdg/to_dag.rs b/src/rvsdg/to_dag.rs index a6b410435..ccb9a3be4 100644 --- a/src/rvsdg/to_dag.rs +++ b/src/rvsdg/to_dag.rs @@ -123,7 +123,7 @@ impl<'a> DagTranslator<'a> { /// the first output using `skip_outputs`. fn translate_subregion( &mut self, - operands: impl Iterator + DoubleEndedIterator, + operands: impl DoubleEndedIterator, ) -> RcExpr { let resulting_exprs = operands.map(|operand| { let res = self.translate_operand(operand); diff --git a/tests/snapshots/files__fib_shape-optimize.snap b/tests/snapshots/files__fib_shape-optimize.snap index 28e318faf..662af857c 100644 --- a/tests/snapshots/files__fib_shape-optimize.snap +++ b/tests/snapshots/files__fib_shape-optimize.snap @@ -20,11 +20,13 @@ expression: visualization.result v9_: int = id v14_; v10_: int = id v5_; v11_: int = id v6_; + v4_: int = id v9_; + v5_: int = id v10_; + v6_: int = id v11_; + jmp .b7_; .b13_: v4_: int = id v9_; v5_: int = id v10_; v6_: int = id v11_; - br v8_ .b7_ .b15_; -.b15_: print v4_; } diff --git a/tests/snapshots/files__flatten_loop-optimize.snap b/tests/snapshots/files__flatten_loop-optimize.snap index f6eccb4d9..a3c090de6 100644 --- a/tests/snapshots/files__flatten_loop-optimize.snap +++ b/tests/snapshots/files__flatten_loop-optimize.snap @@ -42,25 +42,32 @@ expression: visualization.result v27_: int = id v34_; v28_: int = id v21_; v29_: int = id v22_; + v18_: int = id v25_; + v19_: int = id v26_; + v20_: int = id v27_; + v21_: int = id v28_; + v22_: int = id v29_; + jmp .b23_; .b31_: v18_: int = id v25_; v19_: int = id v26_; v20_: int = id v27_; v21_: int = id v28_; v22_: int = id v29_; - br v24_ .b23_ .b35_; -.b35_: - v36_: int = add v5_ v6_; - v11_: int = id v36_; + v35_: int = add v5_ v6_; + v11_: int = id v35_; v12_: int = id v6_; v13_: int = id v7_; v14_: int = id v8_; + v5_: int = id v11_; + v6_: int = id v12_; + v7_: int = id v13_; + v8_: int = id v14_; + jmp .b9_; .b16_: v5_: int = id v11_; v6_: int = id v12_; v7_: int = id v13_; v8_: int = id v14_; - br v10_ .b9_ .b37_; -.b37_: print v5_; } diff --git a/tests/snapshots/files__implicit-return-optimize.snap b/tests/snapshots/files__implicit-return-optimize.snap index 13def91e7..da3565f39 100644 --- a/tests/snapshots/files__implicit-return-optimize.snap +++ b/tests/snapshots/files__implicit-return-optimize.snap @@ -26,13 +26,16 @@ expression: visualization.result v13_: int = id v20_; v14_: int = id v6_; v15_: int = id v7_; + v4_: int = id v12_; + v5_: int = id v13_; + v6_: int = id v14_; + v7_: int = id v15_; + jmp .b8_; .b17_: v4_: int = id v12_; v5_: int = id v13_; v6_: int = id v14_; v7_: int = id v15_; - br v11_ .b8_ .b21_; -.b21_: print v4_; } @main { @@ -60,12 +63,15 @@ expression: visualization.result v12_: int = id v19_; v13_: int = id v6_; v14_: int = id v7_; + v4_: int = id v11_; + v5_: int = id v12_; + v6_: int = id v13_; + v7_: int = id v14_; + jmp .b8_; .b16_: v4_: int = id v11_; v5_: int = id v12_; v6_: int = id v13_; v7_: int = id v14_; - br v10_ .b8_ .b20_; -.b20_: print v4_; } diff --git a/tests/snapshots/files__loop_if-optimize.snap b/tests/snapshots/files__loop_if-optimize.snap index bbf8b9417..5d1067a5c 100644 --- a/tests/snapshots/files__loop_if-optimize.snap +++ b/tests/snapshots/files__loop_if-optimize.snap @@ -20,22 +20,25 @@ expression: visualization.result v6_: int = id v12_; v7_: bool = id v5_; v8_: int = id v13_; -.b9_: c14_: bool = const true; v15_: int = id v6_; v16_: bool = id c14_; v17_: int = id v8_; - br v5_ .b18_ .b19_; .b18_: - c20_: bool = const false; - v15_: int = id v6_; - v16_: bool = id c20_; - v17_: int = id v8_; -.b19_: - v21_: bool = not v5_; + v19_: bool = not v5_; v2_: int = id v6_; v3_: int = id v8_; - br v21_ .b4_ .b22_; -.b22_: + br v19_ .b4_ .b20_; +.b9_: + c14_: bool = const true; + v15_: int = id v6_; + v16_: bool = id c14_; + v17_: int = id v8_; + c21_: bool = const false; + v15_: int = id v6_; + v16_: bool = id c21_; + v17_: int = id v8_; + jmp .b18_; +.b20_: print v2_; } diff --git a/tests/snapshots/files__range_check-optimize.snap b/tests/snapshots/files__range_check-optimize.snap index cfecc22af..20046716a 100644 --- a/tests/snapshots/files__range_check-optimize.snap +++ b/tests/snapshots/files__range_check-optimize.snap @@ -23,14 +23,14 @@ expression: visualization.result br v5_ .b16_ .b17_; .b16_: v15_: int = id v14_; -.b17_: v2_: int = id v15_; - br v5_ .b3_ .b18_; + jmp .b3_; .b9_: - c19_: int = const 2; - print c19_; + c18_: int = const 2; + print c18_; v11_: int = id v2_; jmp .b12_; -.b18_: +.b17_: + v2_: int = id v15_; print v2_; } diff --git a/tests/snapshots/files__range_splitting-optimize.snap b/tests/snapshots/files__range_splitting-optimize.snap index 156f153b9..d99cbf262 100644 --- a/tests/snapshots/files__range_splitting-optimize.snap +++ b/tests/snapshots/files__range_splitting-optimize.snap @@ -22,14 +22,14 @@ expression: visualization.result br v7_ .b15_ .b16_; .b15_: v14_: int = id v5_; -.b16_: v2_: int = id v14_; - br v7_ .b3_ .b17_; + jmp .b3_; .b10_: - c18_: int = const 2; - print c18_; + c17_: int = const 2; + print c17_; v12_: int = id v2_; jmp .b13_; -.b17_: +.b16_: + v2_: int = id v14_; print v2_; } diff --git a/tests/snapshots/files__small-fib-optimize.snap b/tests/snapshots/files__small-fib-optimize.snap index 6899c481e..aa255d216 100644 --- a/tests/snapshots/files__small-fib-optimize.snap +++ b/tests/snapshots/files__small-fib-optimize.snap @@ -31,11 +31,15 @@ expression: visualization.result v14_: int = id v19_; v15_: int = id v8_; v16_: int = id v10_; + v7_: int = id v13_; + v8_: int = id v14_; + v9_: int = id v15_; + v10_: int = id v16_; + jmp .b11_; .b18_: v7_: int = id v13_; v8_: int = id v14_; v9_: int = id v15_; v10_: int = id v16_; - br v12_ .b11_ .b5_; .b5_: } diff --git a/tests/snapshots/files__sqrt-optimize.snap b/tests/snapshots/files__sqrt-optimize.snap index 24a54e20c..69c298317 100644 --- a/tests/snapshots/files__sqrt-optimize.snap +++ b/tests/snapshots/files__sqrt-optimize.snap @@ -72,11 +72,12 @@ expression: visualization.result .b14_: v8_: float = id v13_; v9_: bool = id v12_; -.b11_: - br v9_ .b47_ .b48_; + br v9_ .b11_ .b47_; .b47_: - v49_: float = fdiv v8_ v8_; - print v49_; -.b48_: -.b50_: + ret; +.b11_: + v48_: float = fdiv v8_ v8_; + print v48_; + jmp .b47_; +.b49_: }