From 565e2f537d31cb274ab28e11261a5e0d2226d9fe Mon Sep 17 00:00:00 2001 From: porcuquine Date: Thu, 18 Apr 2024 13:49:09 -0700 Subject: [PATCH] Implement eval recursively. --- examples/ski/ski/src/terms.rs | 131 +++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 48 deletions(-) diff --git a/examples/ski/ski/src/terms.rs b/examples/ski/ski/src/terms.rs index 14a3d908..b52a41ef 100644 --- a/examples/ski/ski/src/terms.rs +++ b/examples/ski/ski/src/terms.rs @@ -1,8 +1,10 @@ #![allow(non_snake_case)] use serde::{Deserialize, Serialize}; use std::borrow::Borrow; +use std::cmp::Eq; use std::collections::HashMap; use std::fmt::Formatter; +use std::hash::Hash; use std::io; use std::str::FromStr; @@ -97,7 +99,7 @@ pub struct Step { depth: usize, } -impl Step { +impl Op { pub fn fmt_to_string(&self, mem: &Mem) -> String { let mut out = Vec::new(); self.fmt(mem, &mut out).unwrap(); @@ -105,28 +107,40 @@ impl Step { } pub fn fmt(&self, mem: &Mem, w: &mut W) -> Result<(), io::Error> { - write!(w, "[{}]", self.depth)?; - for _ in 0..self.depth { - write!(w, " ")?; - } - - match &self.op { + match self { Op::Reduce(term) => { write!(w, "Reduce ")?; - term.fmt(mem, w)?; + term.fmt(mem, w) } Op::Eval(term) => { write!(w, "Eval ")?; - term.fmt(mem, w)?; + term.fmt(mem, w) } Op::Apply(left, right) => { write!(w, "Apply ")?; left.fmt(mem, w)?; write!(w, "<-[")?; right.fmt(mem, w)?; - write!(w, "] ")?; + write!(w, "] ") } } + } +} + +impl Step { + pub fn fmt_to_string(&self, mem: &Mem) -> String { + let mut out = Vec::new(); + self.fmt(mem, &mut out).unwrap(); + String::from_utf8(out).unwrap() + } + + pub fn fmt(&self, mem: &Mem, w: &mut W) -> Result<(), io::Error> { + write!(w, "[{}]", self.depth)?; + for _ in 0..self.depth { + write!(w, " ")?; + } + + self.op.fmt(mem, w)?; write!(w, " => ")?; self.out.fmt(mem, w) } @@ -152,35 +166,43 @@ impl Step { String::from_utf8(out).unwrap() } } + const S_ADDR: usize = 0; const K_ADDR: usize = 1; const I_ADDR: usize = 2; const NIL_ADDR: usize = 3; fn setup(op: Op, mem: &mut Mem, depth: usize) -> (usize, Option) { - let mut existing = None; + let mut found = None; mem.memo.entry(op.clone()).and_modify(|e| { - // `query_value` increments multiplicity. - existing = Some(e.query_value()); + e.1 += 1; + found = Some(e.0.clone()); }); - if let Some(found) = existing { - return (depth, Some(found.clone())); - }; - let step = Step { op, // Placeholder: this will be updated when output is known. - out: mem.I(), + out: found.clone().unwrap_or_else(|| mem.I()), depth, }; mem.steps.push(step); - (mem.steps.len() - 1, None) + + (mem.steps.len() - 1, found) } fn finalize(op: Op, step_index: usize, mem: &mut Mem, result: Term) { mem.steps[step_index].out = result.clone(); - mem.memo.insert(op, WithMultiplicity::first_access(result)); + + mem.memo + .entry(op) + // Just in case it was inserted in the body, after setup. + .and_modify(|e| { + // `query_value` increments multiplicity. + assert_eq!(result, e.query_value()); + + unreachable!("This should never happen."); + }) + .or_insert_with(|| WithMultiplicity::first_access(result)); } macro_rules! with_memo { @@ -190,11 +212,12 @@ macro_rules! with_memo { let op = $op; let (step_index, found) = setup(op.clone(), $mem, $depth); if let Some(found) = found { - return found; - }; - let $result = $body; - finalize(op, step_index, $mem, $result.clone()); - $result + found + } else { + let $result = $body; + finalize(op, step_index, $mem, $result.clone()); + $result + } }}; } @@ -238,7 +261,7 @@ impl Term { } } - fn first(&self, mem: &mut Mem) -> (Self, Option) { + fn first(&self, mem: &Mem) -> (Self, Option) { match self { Self::Cons(_, first, rest) => { let rest_term = mem.get_term(*rest); @@ -367,18 +390,23 @@ impl Term { with_memo!(Op::Reduce(self.clone()), (mem, depth, result), { match self { Self::S3(_, x, y, z) => { + // Q: Do we need to eval x too? + // // Some of these `eval`s could probably be just `reduce`s, and that would // reduce the amount of work performed. However, if we get it wrong that might // result in some terms not being fully reduced. So for now, we will just conservatively // fully evaluate all sub-terms before performing the reduction. let ix = mem.get_term(x); - let iy = mem.get_term(y).eval(mem, depth + 1); - let iz = mem.get_term(z).eval(mem, depth + 1); + let evaled_y = mem.get_term(y).eval(mem, depth + 1); + let evaled_z = mem.get_term(z).eval(mem, depth + 1); let xz = ix - .apply(mem, iz.clone(), depth + 1) + .apply(mem, evaled_z.clone(), depth + 1) .eval(mem, 1 + depth) .clone(); - let yz = iy.apply(mem, iz, depth + 1).eval(mem, depth + 1).clone(); + let yz = evaled_y + .apply(mem, evaled_z, depth + 1) + .eval(mem, depth + 1) + .clone(); xz.apply(mem, yz, depth + 1) } Self::K2(_, x, _) => mem.get_term(x), @@ -386,23 +414,20 @@ impl Term { Self::Cons(_, first, rest) => { let first = mem.get_term(first); let first_evaled = first.clone().eval(mem, depth + 1); - if mem.get_term(rest) == Term::Nil { + let rest = mem.get_term(rest); + + if rest == Term::Nil { if first_evaled == first { first_evaled } else { mem.cons(first_evaled, Self::Nil) } } else { - let rest = mem.get_term(rest); let (second, tail) = rest.first(mem); let second_evaled = second.clone().eval(mem, depth + 1); let applied = first_evaled.clone().apply(mem, second_evaled, depth + 1); - if let Some(tail) = tail { - mem.cons(applied, tail) - } else { - mem.cons(applied, Self::Nil) - } + mem.cons(applied, tail.unwrap_or(Self::Nil)) } } Self::Nil => unreachable!(), @@ -413,18 +438,13 @@ impl Term { pub fn eval(self, mem: &mut Mem, depth: usize) -> Self { with_memo!(Op::Eval(self.clone()), (mem, depth, result), { - let mut prev_addr; - let mut term = self.clone(); - - loop { - prev_addr = term.addr(); - term = term.reduce(mem, depth + 1); + let reduced = self.clone().reduce(mem, depth + 1); - if term.addr() == prev_addr { - break; - }; + if reduced == self { + reduced + } else { + reduced.eval(mem, depth + 1) } - term }) } @@ -520,10 +540,25 @@ impl Mem { &self.terms[addr].0 } - pub fn get_term(&mut self, addr: Addr) -> Term { + pub fn get_term(&self, addr: Addr) -> Term { self.terms[addr].0.clone() } + pub fn query(&self, op: Op) -> Option { + // TODO: this should be indexed to avoid the expensive scan. + if let Some(found) = self.steps.iter().find(|step| step.op == op) { + Some(found.out.clone()) + } else { + None + } + } + + pub fn assert_memo_steps_consistency(&self) { + for step in &self.steps { + assert!(self.memo.get(&step.op).is_some()); + } + } + // NOTE: The clones are shallow. pub fn S(&mut self) -> Term { self.terms[0].query_value().clone()