Skip to content

Commit

Permalink
move rand::thread_rng to seeded SmallRng
Browse files Browse the repository at this point in the history
  • Loading branch information
lorepozo committed Dec 14, 2023
1 parent b9e059c commit 7f00043
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 32 deletions.
5 changes: 3 additions & 2 deletions src/domains/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
//! ```
//! use programinduction::domains::circuits;
//! use programinduction::{ECParams, EC};
//! use rand::{rngs::SmallRng, SeedableRng};
//!
//! let dsl = circuits::dsl();
//! let rng = &mut rand::thread_rng();
//! let rng = &mut SmallRng::from_seed([1u8; 32]);
//! let tasks = circuits::make_tasks(rng, 250);
//! let ec_params = ECParams {
//! frontier_limit: 100,
Expand All @@ -36,7 +37,7 @@ use crate::Task;

/// The circuit representation, a [`lambda::Language`], only defines the binary `nand` operation.
///
/// ```ignore
/// ```compile_fails
/// "nand": ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)])
/// ```
///
Expand Down
5 changes: 3 additions & 2 deletions src/domains/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
//! ```no_run
//! use programinduction::domains::strings;
//! use programinduction::{ECParams, EC};
//! use rand::{rngs::SmallRng, SeedableRng};
//!
//! let dsl = strings::dsl();
//! let rng = &mut rand::thread_rng();
//! let rng = &mut SmallRng::from_seed([1u8; 32]);
//! let tasks = strings::make_tasks(rng, 250, 4);
//! let ec_params = ECParams {
//! frontier_limit: 10,
Expand All @@ -33,7 +34,7 @@ use crate::Task;

/// The string editing [`lambda::Language`] defines the following operations:
///
/// ```ignore
/// ```compile_fails
/// "0": ptp!(int)
/// "+1": ptp!(@arrow[tp!(int), tp!(int)])
/// "-1": ptp!(@arrow[tp!(int), tp!(int)])
Expand Down
6 changes: 4 additions & 2 deletions src/ec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ pub struct ECParams {
/// ```no_run
/// use programinduction::domains::circuits;
/// use programinduction::{lambda, ECParams, EC};
/// use rand::{rngs::SmallRng, SeedableRng};
///
/// let mut dsl = circuits::dsl();
///
/// let rng = &mut rand::thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
/// let tasks = circuits::make_tasks(rng, 250);
/// let ec_params = ECParams {
/// frontier_limit: 10,
Expand Down Expand Up @@ -122,10 +123,11 @@ pub trait EC<Observation: ?Sized>: Sync + Sized {
/// ```no_run
/// use programinduction::domains::circuits;
/// use programinduction::{lambda, ECParams, EC};
/// use rand::{rngs::SmallRng, SeedableRng};
///
/// let mut dsl = circuits::dsl();
///
/// let rng = &mut rand::thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
/// let tasks = circuits::make_tasks(rng, 250);
/// let ec_params = ECParams {
/// frontier_limit: 10,
Expand Down
3 changes: 2 additions & 1 deletion src/lambda/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ impl Language {
/// ```no_run
/// use programinduction::domains::circuits;
/// use programinduction::{lambda, ECParams, EC};
/// use rand::{rngs::SmallRng, SeedableRng};
///
/// let dsl = circuits::dsl();
/// let rng = &mut rand::thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
/// let tasks = circuits::make_tasks(rng, 100);
/// let ec_params = ECParams {
/// frontier_limit: 10,
Expand Down
3 changes: 2 additions & 1 deletion src/pcfg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ impl Grammar {
/// ```
/// use polytype::tp;
/// use programinduction::pcfg::{Grammar, Rule};
/// use rand::{rngs::SmallRng, SeedableRng};
///
/// let g = Grammar::new(
/// tp!(EXPR),
Expand All @@ -211,7 +212,7 @@ impl Grammar {
/// Rule::new("plus", tp!(@arrow[tp!(EXPR), tp!(EXPR), tp!(EXPR)]), 1.0),
/// ],
/// );
/// let ar = g.sample(&tp!(EXPR), &mut rand::thread_rng());
/// let ar = g.sample(&tp!(EXPR), &mut SmallRng::from_seed([1u8; 32]));
/// assert_eq!(&ar.0, &tp!(EXPR));
/// println!("{}", g.display(&ar));
/// ```
Expand Down
3 changes: 2 additions & 1 deletion src/trs/lexicon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl Lexicon {
/// ```
/// use polytype::{ptp, tp, Context as TypeContext};
/// use programinduction::trs::Lexicon;
/// use rand::{rngs::SmallRng, SeedableRng};
///
/// let operators = vec![
/// (2, Some("PLUS".to_string()), ptp![@arrow[tp!(int), tp!(int), tp!(int)]]),
Expand All @@ -260,7 +261,7 @@ impl Lexicon {
/// let atom_weights = (0.5, 0.25, 0.25);
/// let max_size = 50;
///
/// let rng = &mut rand::thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
/// let term = lexicon.sample_term(rng, &schema, &mut ctx, atom_weights, invent, variable, max_size).unwrap();
/// ```
///
Expand Down
38 changes: 19 additions & 19 deletions src/trs/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ impl TRS {
/// # Example
///
/// ```
/// # use polytype::{ptp, tp, Context as TypeContext};
/// # use programinduction::trs::{TRS, Lexicon};
/// # use rand::{thread_rng};
/// # use term_rewriting::{Context, RuleContext, Signature, parse_rule};
/// use polytype::{ptp, tp, Context as TypeContext};
/// use programinduction::trs::{TRS, Lexicon};
/// use rand::{rngs::SmallRng, SeedableRng};
/// use term_rewriting::{Context, RuleContext, Signature, parse_rule};
///
/// let mut sig = Signature::default();
///
Expand Down Expand Up @@ -213,11 +213,11 @@ impl TRS {
/// rhs: vec![Context::Hole],
/// }
/// ];
/// let mut rng = thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
/// let atom_weights = (0.5, 0.25, 0.25);
/// let max_size = 50;
///
/// if let Ok(new_trs) = trs.add_rule(&contexts, atom_weights, max_size, &mut rng) {
/// if let Ok(new_trs) = trs.add_rule(&contexts, atom_weights, max_size, rng) {
/// assert_eq!(new_trs.len(), 3);
/// } else {
/// assert_eq!(trs.len(), 2);
Expand Down Expand Up @@ -270,7 +270,7 @@ impl TRS {
/// ```
/// use polytype::{ptp, tp, Context as TypeContext};
/// use programinduction::trs::{TRS, Lexicon};
/// use rand::{thread_rng};
/// use rand::{rngs::SmallRng, SeedableRng};
/// use term_rewriting::{Context, RuleContext, Signature, parse_rule};
///
/// let mut sig = Signature::default();
Expand Down Expand Up @@ -311,9 +311,9 @@ impl TRS {
///
/// let pretty_before = trs.to_string();
///
/// let mut rng = thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
///
/// let new_trs = trs.randomly_move_rule(&mut rng).expect("failed when moving rule");
/// let new_trs = trs.randomly_move_rule(rng).expect("failed when moving rule");
///
/// assert_ne!(pretty_before, new_trs.to_string());
/// assert_eq!(new_trs.to_string(), "PLUS(x_ SUCC(y_)) = SUCC(PLUS(x_ y_));\nPLUS(x_ ZERO) = x_;");
Expand Down Expand Up @@ -343,7 +343,7 @@ impl TRS {
/// ```
/// use polytype::{ptp, tp, Context as TypeContext};
/// use programinduction::trs::{TRS, Lexicon};
/// use rand::{thread_rng};
/// use rand::{rngs::SmallRng, SeedableRng};
/// use term_rewriting::{Context, RuleContext, Signature, parse_rule};
///
/// let mut sig = Signature::default();
Expand Down Expand Up @@ -380,9 +380,9 @@ impl TRS {
///
/// assert_eq!(trs.len(), 1);
///
/// let mut rng = thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
///
/// if let Ok(new_trs) = trs.local_difference(&mut rng) {
/// if let Ok(new_trs) = trs.local_difference(rng) {
/// assert_eq!(new_trs.len(), 2);
/// let display_str = format!("{}", new_trs);
/// assert_eq!(display_str, "PLUS(x_ SUCC(y_)) = SUCC(PLUS(x_ y_));\nSUCC(PLUS(x_ SUCC(y_))) = SUCC(SUCC(PLUS(x_ y_)));");
Expand Down Expand Up @@ -449,10 +449,10 @@ impl TRS {
/// # Example
///
/// ```
/// # use polytype::{ptp, tp, Context as TypeContext};
/// # use programinduction::trs::{TRS, Lexicon};
/// # use rand::{thread_rng};
/// # use term_rewriting::{Context, RuleContext, Signature, parse_rule};
/// use polytype::{ptp, tp, Context as TypeContext};
/// use programinduction::trs::{TRS, Lexicon};
/// use rand::{rngs::SmallRng, SeedableRng};
/// use term_rewriting::{Context, RuleContext, Signature, parse_rule};
///
/// let mut sig = Signature::default();
///
Expand Down Expand Up @@ -489,9 +489,9 @@ impl TRS {
///
/// assert_eq!(trs.len(), 1);
///
/// let mut rng = thread_rng();
/// let rng = &mut SmallRng::from_seed([1u8; 32]);
///
/// if let Ok(new_trs) = trs.swap_lhs_and_rhs(&mut rng) {
/// if let Ok(new_trs) = trs.swap_lhs_and_rhs(rng) {
/// assert_eq!(new_trs.len(), 2);
/// let display_str = format!("{}", new_trs);
/// assert_eq!(display_str, "SUCC(PLUS(x_ y_)) = PLUS(x_ SUCC(y_));\nPLUS(SUCC(x_) y_) = PLUS(x_ SUCC(y_));");
Expand Down Expand Up @@ -524,7 +524,7 @@ impl TRS {
///
/// let mut trs = TRS::new(&lexicon, rules, &lexicon.context()).unwrap();
///
/// assert!(trs.swap_lhs_and_rhs(&mut rng).is_err());
/// assert!(trs.swap_lhs_and_rhs(rng).is_err());
/// ```
pub fn swap_lhs_and_rhs<R: Rng>(&self, rng: &mut R) -> Result<TRS, SampleError> {
let num_rules = self.len();
Expand Down
5 changes: 3 additions & 2 deletions tests/ec.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rand::{rngs::SmallRng, SeedableRng};
use std::time::Duration;

use polytype::{ptp, tp};
Expand All @@ -19,7 +20,7 @@ fn arith_evaluate(name: &str, inps: &[i32]) -> Result<i32, ()> {
#[ignore]
fn ec_circuits_dl() {
let dsl = circuits::dsl();
let rng = &mut rand::thread_rng();
let rng = &mut SmallRng::from_seed([1u8; 32]);
let tasks = circuits::make_tasks(rng, 100);
let ec_params = ECParams {
frontier_limit: 10,
Expand All @@ -35,7 +36,7 @@ fn ec_circuits_dl() {
#[test]
fn explore_circuits_timeout() {
let dsl = circuits::dsl();
let rng = &mut rand::thread_rng();
let rng = &mut SmallRng::from_seed([1u8; 32]);
let tasks = circuits::make_tasks(rng, 100);
let ec_params = ECParams {
frontier_limit: 10,
Expand Down
12 changes: 10 additions & 2 deletions tests/lambda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,18 @@ fn lambda_eval_somewhat_simple() {

let expr = dsl.parse("(λ (not (eq (+ 1 $0) 1)))").unwrap();

let out = dsl.eval(&expr, SimpleEvaluator::from(evaluate), &[ArithSpace::Num(1)]);
let out = dsl.eval(
&expr,
SimpleEvaluator::from(evaluate),
&[ArithSpace::Num(1)],
);
assert_eq!(out, Ok(ArithSpace::Bool(true)));

let out = dsl.eval(&expr, SimpleEvaluator::from(evaluate), &[ArithSpace::Num(0)]);
let out = dsl.eval(
&expr,
SimpleEvaluator::from(evaluate),
&[ArithSpace::Num(0)],
);
assert_eq!(out, Ok(ArithSpace::Bool(false)));
}

Expand Down

0 comments on commit 7f00043

Please sign in to comment.