Skip to content

Commit

Permalink
Implement eval recursively.
Browse files Browse the repository at this point in the history
  • Loading branch information
porcuquine committed Apr 19, 2024
1 parent 7423db2 commit 565e2f5
Showing 1 changed file with 83 additions and 48 deletions.
131 changes: 83 additions & 48 deletions examples/ski/ski/src/terms.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![allow(non_snake_case)]
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::Eq;
use std::collections::HashMap;
use std::fmt::Formatter;
use std::hash::Hash;
use std::io;
use std::str::FromStr;

Expand Down Expand Up @@ -97,36 +99,48 @@ pub struct Step {
depth: usize,
}

impl Step {
impl Op {
pub fn fmt_to_string(&self, mem: &Mem) -> String {
let mut out = Vec::new();
self.fmt(mem, &mut out).unwrap();
String::from_utf8(out).unwrap()
}

pub fn fmt<W: io::Write>(&self, mem: &Mem, w: &mut W) -> Result<(), io::Error> {
write!(w, "[{}]", self.depth)?;
for _ in 0..self.depth {
write!(w, " ")?;
}

match &self.op {
match self {
Op::Reduce(term) => {
write!(w, "Reduce ")?;
term.fmt(mem, w)?;
term.fmt(mem, w)
}
Op::Eval(term) => {
write!(w, "Eval ")?;
term.fmt(mem, w)?;
term.fmt(mem, w)
}
Op::Apply(left, right) => {
write!(w, "Apply ")?;
left.fmt(mem, w)?;
write!(w, "<-[")?;
right.fmt(mem, w)?;
write!(w, "] ")?;
write!(w, "] ")
}
}
}
}

impl Step {
pub fn fmt_to_string(&self, mem: &Mem) -> String {
let mut out = Vec::new();
self.fmt(mem, &mut out).unwrap();
String::from_utf8(out).unwrap()
}

pub fn fmt<W: io::Write>(&self, mem: &Mem, w: &mut W) -> Result<(), io::Error> {
write!(w, "[{}]", self.depth)?;
for _ in 0..self.depth {
write!(w, " ")?;
}

self.op.fmt(mem, w)?;
write!(w, " => ")?;
self.out.fmt(mem, w)
}
Expand All @@ -152,35 +166,43 @@ impl Step {
String::from_utf8(out).unwrap()
}
}

const S_ADDR: usize = 0;
const K_ADDR: usize = 1;
const I_ADDR: usize = 2;
const NIL_ADDR: usize = 3;

fn setup(op: Op, mem: &mut Mem, depth: usize) -> (usize, Option<Term>) {
let mut existing = None;
let mut found = None;
mem.memo.entry(op.clone()).and_modify(|e| {
// `query_value` increments multiplicity.
existing = Some(e.query_value());
e.1 += 1;
found = Some(e.0.clone());
});

if let Some(found) = existing {
return (depth, Some(found.clone()));
};

let step = Step {
op,
// Placeholder: this will be updated when output is known.
out: mem.I(),
out: found.clone().unwrap_or_else(|| mem.I()),
depth,
};
mem.steps.push(step);
(mem.steps.len() - 1, None)

(mem.steps.len() - 1, found)
}

fn finalize(op: Op, step_index: usize, mem: &mut Mem, result: Term) {
mem.steps[step_index].out = result.clone();
mem.memo.insert(op, WithMultiplicity::first_access(result));

mem.memo
.entry(op)
// Just in case it was inserted in the body, after setup.
.and_modify(|e| {
// `query_value` increments multiplicity.
assert_eq!(result, e.query_value());

unreachable!("This should never happen.");
})
.or_insert_with(|| WithMultiplicity::first_access(result));
}

macro_rules! with_memo {
Expand All @@ -190,11 +212,12 @@ macro_rules! with_memo {
let op = $op;
let (step_index, found) = setup(op.clone(), $mem, $depth);
if let Some(found) = found {
return found;
};
let $result = $body;
finalize(op, step_index, $mem, $result.clone());
$result
found
} else {
let $result = $body;
finalize(op, step_index, $mem, $result.clone());
$result
}
}};
}

Expand Down Expand Up @@ -238,7 +261,7 @@ impl Term {
}
}

fn first(&self, mem: &mut Mem) -> (Self, Option<Self>) {
fn first(&self, mem: &Mem) -> (Self, Option<Self>) {
match self {
Self::Cons(_, first, rest) => {
let rest_term = mem.get_term(*rest);
Expand Down Expand Up @@ -367,42 +390,44 @@ impl Term {
with_memo!(Op::Reduce(self.clone()), (mem, depth, result), {
match self {
Self::S3(_, x, y, z) => {
// Q: Do we need to eval x too?
//
// Some of these `eval`s could probably be just `reduce`s, and that would
// reduce the amount of work performed. However, if we get it wrong that might
// result in some terms not being fully reduced. So for now, we will just conservatively
// fully evaluate all sub-terms before performing the reduction.
let ix = mem.get_term(x);
let iy = mem.get_term(y).eval(mem, depth + 1);
let iz = mem.get_term(z).eval(mem, depth + 1);
let evaled_y = mem.get_term(y).eval(mem, depth + 1);
let evaled_z = mem.get_term(z).eval(mem, depth + 1);
let xz = ix
.apply(mem, iz.clone(), depth + 1)
.apply(mem, evaled_z.clone(), depth + 1)
.eval(mem, 1 + depth)
.clone();
let yz = iy.apply(mem, iz, depth + 1).eval(mem, depth + 1).clone();
let yz = evaled_y
.apply(mem, evaled_z, depth + 1)
.eval(mem, depth + 1)
.clone();
xz.apply(mem, yz, depth + 1)
}
Self::K2(_, x, _) => mem.get_term(x),
Self::I1(_, x) => mem.get_term(x),
Self::Cons(_, first, rest) => {
let first = mem.get_term(first);
let first_evaled = first.clone().eval(mem, depth + 1);
if mem.get_term(rest) == Term::Nil {
let rest = mem.get_term(rest);

if rest == Term::Nil {
if first_evaled == first {
first_evaled
} else {
mem.cons(first_evaled, Self::Nil)
}
} else {
let rest = mem.get_term(rest);
let (second, tail) = rest.first(mem);
let second_evaled = second.clone().eval(mem, depth + 1);
let applied = first_evaled.clone().apply(mem, second_evaled, depth + 1);

if let Some(tail) = tail {
mem.cons(applied, tail)
} else {
mem.cons(applied, Self::Nil)
}
mem.cons(applied, tail.unwrap_or(Self::Nil))
}
}
Self::Nil => unreachable!(),
Expand All @@ -413,18 +438,13 @@ impl Term {

pub fn eval(self, mem: &mut Mem, depth: usize) -> Self {
with_memo!(Op::Eval(self.clone()), (mem, depth, result), {
let mut prev_addr;
let mut term = self.clone();

loop {
prev_addr = term.addr();
term = term.reduce(mem, depth + 1);
let reduced = self.clone().reduce(mem, depth + 1);

if term.addr() == prev_addr {
break;
};
if reduced == self {
reduced
} else {
reduced.eval(mem, depth + 1)
}
term
})
}

Expand Down Expand Up @@ -520,10 +540,25 @@ impl Mem {
&self.terms[addr].0
}

pub fn get_term(&mut self, addr: Addr) -> Term {
pub fn get_term(&self, addr: Addr) -> Term {
self.terms[addr].0.clone()
}

pub fn query(&self, op: Op) -> Option<Term> {
// TODO: this should be indexed to avoid the expensive scan.
if let Some(found) = self.steps.iter().find(|step| step.op == op) {
Some(found.out.clone())
} else {
None
}
}

pub fn assert_memo_steps_consistency(&self) {
for step in &self.steps {
assert!(self.memo.get(&step.op).is_some());
}
}

// NOTE: The clones are shallow.
pub fn S(&mut self) -> Term {
self.terms[0].query_value().clone()
Expand Down

0 comments on commit 565e2f5

Please sign in to comment.