Skip to content

Commit

Permalink
refactor Task into a trait
Browse files Browse the repository at this point in the history
  • Loading branch information
lorepozo committed Dec 14, 2023
1 parent 68c0a52 commit b9e059c
Show file tree
Hide file tree
Showing 15 changed files with 416 additions and 299 deletions.
43 changes: 21 additions & 22 deletions examples/json_compressor.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use polytype::TypeSchema;
use programinduction::{lambda, ECFrontier, Task};
use programinduction::{lambda, noop_task, ECFrontier};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::f64;

#[derive(Deserialize)]
struct ExternalCompressionInput {
Expand Down Expand Up @@ -56,15 +55,11 @@ struct Solution {
loglikelihood: f64,
}

fn noop_oracle(_: &lambda::Language, _: &lambda::Expression) -> f64 {
f64::NEG_INFINITY
}

struct CompressionInput {
dsl: lambda::Language,
params: lambda::CompressionParams,
tasks: Vec<Task<'static, lambda::Language, lambda::Expression, ()>>,
frontiers: Vec<ECFrontier<lambda::Language>>,
task_types: Vec<TypeSchema>,
frontiers: Vec<ECFrontier<lambda::Expression>>,
}
impl From<ExternalCompressionInput> for CompressionInput {
fn from(eci: ExternalCompressionInput) -> Self {
Expand Down Expand Up @@ -99,16 +94,11 @@ impl From<ExternalCompressionInput> for CompressionInput {
aic: eci.params.aic,
arity: eci.params.arity,
};
let (tasks, frontiers) = eci
let (task_types, frontiers) = eci
.frontiers
.into_par_iter()
.map(|f| {
let tp = TypeSchema::parse(&f.task_tp).expect("invalid task type");
let task = Task {
oracle: Box::new(noop_oracle),
observation: (),
tp,
};
let sols = f
.solutions
.into_iter()
Expand All @@ -119,13 +109,13 @@ impl From<ExternalCompressionInput> for CompressionInput {
(expr, s.logprior, s.loglikelihood)
})
.collect();
(task, ECFrontier(sols))
(tp, ECFrontier(sols))
})
.unzip();
CompressionInput {
dsl,
params,
tasks,
task_types,
frontiers,
}
}
Expand Down Expand Up @@ -153,10 +143,10 @@ impl From<CompressionInput> for ExternalCompressionOutput {
})
.collect();
let frontiers = ci
.tasks
.task_types
.par_iter()
.zip(&ci.frontiers)
.map(|(t, f)| {
.map(|(tp, f)| {
let solutions = f
.iter()
.map(|&(ref expr, logprior, loglikelihood)| {
Expand All @@ -169,7 +159,7 @@ impl From<CompressionInput> for ExternalCompressionOutput {
})
.collect();
Frontier {
task_tp: format!("{}", t.tp),
task_tp: tp.to_string(),
solutions,
}
})
Expand All @@ -187,9 +177,18 @@ fn main() {
let eci: ExternalCompressionInput =
serde_json::from_slice(include_bytes!("realistic_input.json")).expect("invalid json");

let ci = CompressionInput::from(eci);
let (dsl, _) = ci.dsl.compress(&ci.params, &ci.tasks, ci.frontiers);
for i in ci.dsl.invented.len()..dsl.invented.len() {
let CompressionInput {
dsl,
params,
task_types,
frontiers,
} = CompressionInput::from(eci);
let tasks = task_types
.into_iter()
.map(|tp| noop_task(f64::NEG_INFINITY, tp))
.collect::<Vec<_>>();
let (dsl, _) = dsl.compress(&params, &tasks, frontiers);
for i in dsl.invented.len()..dsl.invented.len() {
let (expr, _, _) = &dsl.invented[i];
eprintln!("invented {}", dsl.display(expr));
}
Expand Down
76 changes: 49 additions & 27 deletions src/domains/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use rand::{
Rng,
};
use std::iter;
use std::sync::Arc;

use crate::lambda::{Evaluator as EvaluatorT, Expression, Language};
use crate::Task;
Expand Down Expand Up @@ -121,7 +122,7 @@ impl EvaluatorT for Evaluator {
pub fn make_tasks<R: Rng>(
rng: &mut R,
count: u32,
) -> Vec<Task<'static, Language, Expression, Vec<bool>>> {
) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
make_tasks_advanced(
rng,
count,
Expand Down Expand Up @@ -159,7 +160,7 @@ pub fn make_tasks_advanced<R: Rng>(
gate_or: u32,
gate_mux2: u32,
gate_mux4: u32,
) -> Vec<Task<'static, Language, Expression, Vec<bool>>> {
) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
let n_input_distribution =
WeightedIndex::new(n_input_weights).expect("invalid weights for number of circuit inputs");
let n_gate_distribution =
Expand All @@ -175,40 +176,61 @@ pub fn make_tasks_advanced<R: Rng>(
n_inputs = 1 + n_input_distribution.sample(rng);
n_gates = 1 + n_gate_distribution.sample(rng);
}
let tp = TypeSchema::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
let circuit = gates::Circuit::new(rng, &gate_weights, n_inputs as u32, n_gates);
let outputs: Vec<_> = iter::repeat(vec![false, true])
.take(n_inputs)
.multi_cartesian_product()
.map(|ins| circuit.eval(&ins))
.collect();
let oracle_outputs = outputs.clone();
let evaluator = std::sync::Arc::new(Evaluator);
let oracle = Box::new(move |dsl: &Language, expr: &Expression| -> f64 {
let success = iter::repeat(vec![false, true])
.take(n_inputs)
.multi_cartesian_product()
.zip(&oracle_outputs)
.all(|(inps, out)| {
if let Ok(o) = dsl.eval_arc(expr, &evaluator, &inps) {
o == *out
} else {
false
}
});
if success {
0f64
CircuitTask::new(n_inputs, outputs)
})
.collect()
}

struct CircuitTask {
n_inputs: usize,
expected_outputs: Vec<bool>,
tp: TypeSchema,
}
impl CircuitTask {
fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
let tp = TypeSchema::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
CircuitTask {
n_inputs,
expected_outputs,
tp,
}
}
}
impl Task<[bool]> for CircuitTask {
type Representation = Language;
type Expression = Expression;

fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64 {
let evaluator = Arc::new(Evaluator);
let success = iter::repeat(vec![false, true])
.take(self.n_inputs)
.multi_cartesian_product()
.zip(&self.expected_outputs)
.all(|(inps, out)| {
if let Ok(o) = dsl.eval_arc(expr, &evaluator, &inps) {
o == *out
} else {
f64::NEG_INFINITY
false
}
});
Task {
oracle,
observation: outputs,
tp,
}
})
.collect()
if success {
0f64
} else {
f64::NEG_INFINITY
}
}
fn tp(&self) -> &TypeSchema {
&self.tp
}
fn observation(&self) -> &[bool] {
&self.expected_outputs
}
}

mod gates {
Expand Down
6 changes: 3 additions & 3 deletions src/domains/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use rand::Rng;
use std::collections::HashMap;

use crate::lambda::{
task_by_evaluation_owned, Evaluator as EvaluatorT, Expression, Language, LiftedFunction,
task_by_evaluation, Evaluator as EvaluatorT, Expression, Language, LiftedFunction,
};
use crate::Task;

Expand Down Expand Up @@ -302,11 +302,11 @@ pub fn make_tasks<R: Rng>(
rng: &mut R,
count: usize,
n_examples: usize,
) -> Vec<Task<'static, Language, Expression, Vec<(Vec<Space>, Space)>>> {
) -> Vec<impl Task<[(Vec<Space>, Space)], Representation = Language, Expression = Expression>> {
(0..=count / 1467) // make_examples yields 1467 tasks
.flat_map(|_| gen::make_examples(rng, n_examples))
.take(count)
.map(|(_name, tp, examples)| task_by_evaluation_owned(Evaluator, tp, examples))
.map(|(_name, tp, examples)| task_by_evaluation(Evaluator, tp, examples))
.collect()
}

Expand Down
Loading

0 comments on commit b9e059c

Please sign in to comment.