diff --git a/src/lib.rs b/src/lib.rs index a0a346e..dec8dc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,6 +112,9 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result = RefCell::new(FxHashMap::default()); + let symbolic_shape_specialization_index: RefCell = + RefCell::new(FxHashMap::default()); + // Store results in an output Vec let mut output: Vec<(PathBuf, String)> = Vec::new(); @@ -145,7 +148,12 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result anyhow::Result anyhow::Result match parser.name() { "dynamo_guards" => { - eprintln!("Failed to parse guards json: {}", err); + multi.suspend(|| eprintln!("Failed to parse guards json: {}", err)); stats.fail_dynamo_guards_json += 1; } name => { - eprintln!("Parser {name} failed: {err}"); + multi.suspend(|| eprintln!("Parser {name} failed: {err}")); stats.fail_parser += 1; } }, @@ -290,6 +298,14 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result String { + let mut trie = StackTrieNode::default(); + trie.insert_no_terminal(stack.to_vec()); + trie.fmt(None).unwrap() +} + pub struct CompilationMetricsParser<'t> { - tt: &'t TinyTemplate<'t>, - stack_index: &'t RefCell, + pub tt: &'t TinyTemplate<'t>, + pub stack_index: &'t RefCell, + pub symbolic_shape_specialization_index: &'t RefCell, } impl StructuredLogParser for CompilationMetricsParser<'_> { fn name(&self) -> &'static str { @@ -341,16 +348,27 @@ impl StructuredLogParser for CompilationMetricsParser<'_> { .stack_index .borrow() .get(&cid) - .map_or("".to_string(), |stack| { - let mut trie = StackTrieNode::default(); - trie.insert_no_terminal(stack.to_vec()); - trie.fmt(None).unwrap() - }); + .map_or("".to_string(), format_stack); + let specializations = self + .symbolic_shape_specialization_index + .borrow_mut() + .remove(&cid) + .unwrap_or(Vec::new()) + .drain(..) + .map(|spec| SymbolicShapeSpecializationContext { + symbol: spec.symbol.unwrap_or("".to_string()), + sources: spec.sources.unwrap_or(Vec::new()), + value: spec.value.unwrap_or("".to_string()), + user_stack_html: format_stack(&spec.user_stack.unwrap_or(Vec::new())), + stack_html: format_stack(&spec.stack.unwrap_or(Vec::new())), + }) + .collect(); let context = CompilationMetricsContext { css: crate::CSS, m: &m, compile_id: id, stack_html: stack_html, + symbolic_shape_specializations: specializations, }; let output = self.tt.render(&filename, &context)?; simple_file_output(&filename, lineno, compile_id, &output) @@ -401,10 +419,7 @@ impl StructuredLogParser for AOTAutogradBackwardCompilationMetricsParser<'_> { } // Register your parser here -pub fn default_parsers<'t>( - tt: &'t TinyTemplate<'t>, - stack_index: &'t RefCell, -) -> Vec> { +pub fn default_parsers<'t>(tt: &'t TinyTemplate<'t>) -> Vec> { // We need to use Box wrappers here because vecs in Rust need to have known size let result: Vec> = vec![ Box::new(SentinelFileParser::new("optimize_ddp_split_graph", |e| { @@ -430,7 +445,6 @@ pub fn default_parsers<'t>( Box::new(DynamoGuardParser { tt }), Box::new(InductorOutputCodeParser), Box::new(OptimizeDdpSplitChildParser), - Box::new(CompilationMetricsParser { tt, stack_index }), // TODO: use own tt instances Box::new(AOTAutogradBackwardCompilationMetricsParser { tt }), // TODO: use own tt instances Box::new(LinkParser), ]; diff --git a/src/templates.rs b/src/templates.rs index 2c401a0..b1968ea 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -259,6 +259,21 @@ pub static TEMPLATE_COMPILATION_METRICS: &str = r#" {{ for op in m.non_compliant_ops }}
  • {op}
  • {{ endfor }} +

    Symbolic shape specializations

    + + + + + {{ for spec in symbolic_shape_specializations }} + + + + + + + + {{ endfor }} +
    Sym Source(s) Value User stack Framework stack
    {spec.symbol}{{ for source in spec.sources }}{source}
    {{ endfor }}
    {spec.value}{spec.user_stack_html | format_unescaped}{spec.stack_html | format_unescaped}
    "#; diff --git a/src/types.rs b/src/types.rs index 7a1087a..5882d7f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,6 +15,8 @@ use std::sync::Mutex; pub type ParseOutput = Vec<(PathBuf, String)>; pub type CompilationMetricsIndex = FxIndexMap, Vec>; pub type StackIndex = FxHashMap, StackSummary>; // NB: attempt is always 0 here +pub type SymbolicShapeSpecializationIndex = + FxHashMap, Vec>; pub type FxIndexMap = IndexMap>; @@ -265,6 +267,16 @@ pub struct AOTAutogradBackwardCompilationMetricsMetadata { pub fail_reason: Option, } +#[derive(Debug, Deserialize, Serialize)] +pub struct SymbolicShapeSpecializationMetadata { + pub symbol: Option, + pub sources: Option>, + pub value: Option, + pub reason: Option, + pub stack: Option, + pub user_stack: Option, +} + #[derive(Debug, Serialize)] pub struct AOTAutogradBackwardCompilationMetricsContext<'e> { pub m: &'e AOTAutogradBackwardCompilationMetricsMetadata, @@ -278,6 +290,7 @@ pub struct CompilationMetricsContext<'e> { pub css: &'static str, pub compile_id: String, pub stack_html: String, + pub symbolic_shape_specializations: Vec, } #[derive(Debug, Serialize)] @@ -360,6 +373,7 @@ pub struct Envelope { Option, pub graph_dump: Option, pub link: Option, + pub symbolic_shape_specialization: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -385,3 +399,12 @@ pub struct IndexContext { pub num_breaks: usize, pub custom_header_html: String, } + +#[derive(Debug, Serialize)] +pub struct SymbolicShapeSpecializationContext { + pub symbol: String, + pub sources: Vec, + pub value: String, + pub user_stack_html: String, + pub stack_html: String, +}