From 7f00043d2f31f2087caa2f7d8452f744d280cbf1 Mon Sep 17 00:00:00 2001 From: Lore Anaya Pozo Date: Thu, 14 Dec 2023 06:30:36 -0500 Subject: [PATCH] move rand::thread_rng to seeded SmallRng --- src/domains/circuits.rs | 5 +++-- src/domains/strings.rs | 5 +++-- src/ec.rs | 6 ++++-- src/lambda/mod.rs | 3 ++- src/pcfg/mod.rs | 3 ++- src/trs/lexicon.rs | 3 ++- src/trs/rewrite.rs | 38 +++++++++++++++++++------------------- tests/ec.rs | 5 +++-- tests/lambda.rs | 12 ++++++++++-- 9 files changed, 48 insertions(+), 32 deletions(-) diff --git a/src/domains/circuits.rs b/src/domains/circuits.rs index 3ea489e..fea6026 100644 --- a/src/domains/circuits.rs +++ b/src/domains/circuits.rs @@ -7,9 +7,10 @@ //! ``` //! use programinduction::domains::circuits; //! use programinduction::{ECParams, EC}; +//! use rand::{rngs::SmallRng, SeedableRng}; //! //! let dsl = circuits::dsl(); -//! let rng = &mut rand::thread_rng(); +//! let rng = &mut SmallRng::from_seed([1u8; 32]); //! let tasks = circuits::make_tasks(rng, 250); //! let ec_params = ECParams { //! frontier_limit: 100, @@ -36,7 +37,7 @@ use crate::Task; /// The circuit representation, a [`lambda::Language`], only defines the binary `nand` operation. /// -/// ```ignore +/// ```compile_fails /// "nand": ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)]) /// ``` /// diff --git a/src/domains/strings.rs b/src/domains/strings.rs index 58efd86..4c82900 100644 --- a/src/domains/strings.rs +++ b/src/domains/strings.rs @@ -5,9 +5,10 @@ //! ```no_run //! use programinduction::domains::strings; //! use programinduction::{ECParams, EC}; +//! use rand::{rngs::SmallRng, SeedableRng}; //! //! let dsl = strings::dsl(); -//! let rng = &mut rand::thread_rng(); +//! let rng = &mut SmallRng::from_seed([1u8; 32]); //! let tasks = strings::make_tasks(rng, 250, 4); //! let ec_params = ECParams { //! frontier_limit: 10, @@ -33,7 +34,7 @@ use crate::Task; /// The string editing [`lambda::Language`] defines the following operations: /// -/// ```ignore +/// ```compile_fails /// "0": ptp!(int) /// "+1": ptp!(@arrow[tp!(int), tp!(int)]) /// "-1": ptp!(@arrow[tp!(int), tp!(int)]) diff --git a/src/ec.rs b/src/ec.rs index ad4d3be..eaf9372 100644 --- a/src/ec.rs +++ b/src/ec.rs @@ -44,10 +44,11 @@ pub struct ECParams { /// ```no_run /// use programinduction::domains::circuits; /// use programinduction::{lambda, ECParams, EC}; +/// use rand::{rngs::SmallRng, SeedableRng}; /// /// let mut dsl = circuits::dsl(); /// -/// let rng = &mut rand::thread_rng(); +/// let rng = &mut SmallRng::from_seed([1u8; 32]); /// let tasks = circuits::make_tasks(rng, 250); /// let ec_params = ECParams { /// frontier_limit: 10, @@ -122,10 +123,11 @@ pub trait EC: Sync + Sized { /// ```no_run /// use programinduction::domains::circuits; /// use programinduction::{lambda, ECParams, EC}; + /// use rand::{rngs::SmallRng, SeedableRng}; /// /// let mut dsl = circuits::dsl(); /// - /// let rng = &mut rand::thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// let tasks = circuits::make_tasks(rng, 250); /// let ec_params = ECParams { /// frontier_limit: 10, diff --git a/src/lambda/mod.rs b/src/lambda/mod.rs index fba22d3..ce1fcdb 100644 --- a/src/lambda/mod.rs +++ b/src/lambda/mod.rs @@ -188,9 +188,10 @@ impl Language { /// ```no_run /// use programinduction::domains::circuits; /// use programinduction::{lambda, ECParams, EC}; + /// use rand::{rngs::SmallRng, SeedableRng}; /// /// let dsl = circuits::dsl(); - /// let rng = &mut rand::thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// let tasks = circuits::make_tasks(rng, 100); /// let ec_params = ECParams { /// frontier_limit: 10, diff --git a/src/pcfg/mod.rs b/src/pcfg/mod.rs index 7fcce05..35d64dc 100644 --- a/src/pcfg/mod.rs +++ b/src/pcfg/mod.rs @@ -202,6 +202,7 @@ impl Grammar { /// ``` /// use polytype::tp; /// use programinduction::pcfg::{Grammar, Rule}; + /// use rand::{rngs::SmallRng, SeedableRng}; /// /// let g = Grammar::new( /// tp!(EXPR), @@ -211,7 +212,7 @@ impl Grammar { /// Rule::new("plus", tp!(@arrow[tp!(EXPR), tp!(EXPR), tp!(EXPR)]), 1.0), /// ], /// ); - /// let ar = g.sample(&tp!(EXPR), &mut rand::thread_rng()); + /// let ar = g.sample(&tp!(EXPR), &mut SmallRng::from_seed([1u8; 32])); /// assert_eq!(&ar.0, &tp!(EXPR)); /// println!("{}", g.display(&ar)); /// ``` diff --git a/src/trs/lexicon.rs b/src/trs/lexicon.rs index bec2f90..cd91237 100644 --- a/src/trs/lexicon.rs +++ b/src/trs/lexicon.rs @@ -244,6 +244,7 @@ impl Lexicon { /// ``` /// use polytype::{ptp, tp, Context as TypeContext}; /// use programinduction::trs::Lexicon; + /// use rand::{rngs::SmallRng, SeedableRng}; /// /// let operators = vec![ /// (2, Some("PLUS".to_string()), ptp![@arrow[tp!(int), tp!(int), tp!(int)]]), @@ -260,7 +261,7 @@ impl Lexicon { /// let atom_weights = (0.5, 0.25, 0.25); /// let max_size = 50; /// - /// let rng = &mut rand::thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// let term = lexicon.sample_term(rng, &schema, &mut ctx, atom_weights, invent, variable, max_size).unwrap(); /// ``` /// diff --git a/src/trs/rewrite.rs b/src/trs/rewrite.rs index 2115f0b..498f032 100644 --- a/src/trs/rewrite.rs +++ b/src/trs/rewrite.rs @@ -167,10 +167,10 @@ impl TRS { /// # Example /// /// ``` - /// # use polytype::{ptp, tp, Context as TypeContext}; - /// # use programinduction::trs::{TRS, Lexicon}; - /// # use rand::{thread_rng}; - /// # use term_rewriting::{Context, RuleContext, Signature, parse_rule}; + /// use polytype::{ptp, tp, Context as TypeContext}; + /// use programinduction::trs::{TRS, Lexicon}; + /// use rand::{rngs::SmallRng, SeedableRng}; + /// use term_rewriting::{Context, RuleContext, Signature, parse_rule}; /// /// let mut sig = Signature::default(); /// @@ -213,11 +213,11 @@ impl TRS { /// rhs: vec![Context::Hole], /// } /// ]; - /// let mut rng = thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// let atom_weights = (0.5, 0.25, 0.25); /// let max_size = 50; /// - /// if let Ok(new_trs) = trs.add_rule(&contexts, atom_weights, max_size, &mut rng) { + /// if let Ok(new_trs) = trs.add_rule(&contexts, atom_weights, max_size, rng) { /// assert_eq!(new_trs.len(), 3); /// } else { /// assert_eq!(trs.len(), 2); @@ -270,7 +270,7 @@ impl TRS { /// ``` /// use polytype::{ptp, tp, Context as TypeContext}; /// use programinduction::trs::{TRS, Lexicon}; - /// use rand::{thread_rng}; + /// use rand::{rngs::SmallRng, SeedableRng}; /// use term_rewriting::{Context, RuleContext, Signature, parse_rule}; /// /// let mut sig = Signature::default(); @@ -311,9 +311,9 @@ impl TRS { /// /// let pretty_before = trs.to_string(); /// - /// let mut rng = thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// - /// let new_trs = trs.randomly_move_rule(&mut rng).expect("failed when moving rule"); + /// let new_trs = trs.randomly_move_rule(rng).expect("failed when moving rule"); /// /// assert_ne!(pretty_before, new_trs.to_string()); /// assert_eq!(new_trs.to_string(), "PLUS(x_ SUCC(y_)) = SUCC(PLUS(x_ y_));\nPLUS(x_ ZERO) = x_;"); @@ -343,7 +343,7 @@ impl TRS { /// ``` /// use polytype::{ptp, tp, Context as TypeContext}; /// use programinduction::trs::{TRS, Lexicon}; - /// use rand::{thread_rng}; + /// use rand::{rngs::SmallRng, SeedableRng}; /// use term_rewriting::{Context, RuleContext, Signature, parse_rule}; /// /// let mut sig = Signature::default(); @@ -380,9 +380,9 @@ impl TRS { /// /// assert_eq!(trs.len(), 1); /// - /// let mut rng = thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// - /// if let Ok(new_trs) = trs.local_difference(&mut rng) { + /// if let Ok(new_trs) = trs.local_difference(rng) { /// assert_eq!(new_trs.len(), 2); /// let display_str = format!("{}", new_trs); /// assert_eq!(display_str, "PLUS(x_ SUCC(y_)) = SUCC(PLUS(x_ y_));\nSUCC(PLUS(x_ SUCC(y_))) = SUCC(SUCC(PLUS(x_ y_)));"); @@ -449,10 +449,10 @@ impl TRS { /// # Example /// /// ``` - /// # use polytype::{ptp, tp, Context as TypeContext}; - /// # use programinduction::trs::{TRS, Lexicon}; - /// # use rand::{thread_rng}; - /// # use term_rewriting::{Context, RuleContext, Signature, parse_rule}; + /// use polytype::{ptp, tp, Context as TypeContext}; + /// use programinduction::trs::{TRS, Lexicon}; + /// use rand::{rngs::SmallRng, SeedableRng}; + /// use term_rewriting::{Context, RuleContext, Signature, parse_rule}; /// /// let mut sig = Signature::default(); /// @@ -489,9 +489,9 @@ impl TRS { /// /// assert_eq!(trs.len(), 1); /// - /// let mut rng = thread_rng(); + /// let rng = &mut SmallRng::from_seed([1u8; 32]); /// - /// if let Ok(new_trs) = trs.swap_lhs_and_rhs(&mut rng) { + /// if let Ok(new_trs) = trs.swap_lhs_and_rhs(rng) { /// assert_eq!(new_trs.len(), 2); /// let display_str = format!("{}", new_trs); /// assert_eq!(display_str, "SUCC(PLUS(x_ y_)) = PLUS(x_ SUCC(y_));\nPLUS(SUCC(x_) y_) = PLUS(x_ SUCC(y_));"); @@ -524,7 +524,7 @@ impl TRS { /// /// let mut trs = TRS::new(&lexicon, rules, &lexicon.context()).unwrap(); /// - /// assert!(trs.swap_lhs_and_rhs(&mut rng).is_err()); + /// assert!(trs.swap_lhs_and_rhs(rng).is_err()); /// ``` pub fn swap_lhs_and_rhs(&self, rng: &mut R) -> Result { let num_rules = self.len(); diff --git a/tests/ec.rs b/tests/ec.rs index 9cbb752..17dc7c5 100644 --- a/tests/ec.rs +++ b/tests/ec.rs @@ -1,3 +1,4 @@ +use rand::{rngs::SmallRng, SeedableRng}; use std::time::Duration; use polytype::{ptp, tp}; @@ -19,7 +20,7 @@ fn arith_evaluate(name: &str, inps: &[i32]) -> Result { #[ignore] fn ec_circuits_dl() { let dsl = circuits::dsl(); - let rng = &mut rand::thread_rng(); + let rng = &mut SmallRng::from_seed([1u8; 32]); let tasks = circuits::make_tasks(rng, 100); let ec_params = ECParams { frontier_limit: 10, @@ -35,7 +36,7 @@ fn ec_circuits_dl() { #[test] fn explore_circuits_timeout() { let dsl = circuits::dsl(); - let rng = &mut rand::thread_rng(); + let rng = &mut SmallRng::from_seed([1u8; 32]); let tasks = circuits::make_tasks(rng, 100); let ec_params = ECParams { frontier_limit: 10, diff --git a/tests/lambda.rs b/tests/lambda.rs index 31fc5a6..9dc2e35 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -248,10 +248,18 @@ fn lambda_eval_somewhat_simple() { let expr = dsl.parse("(λ (not (eq (+ 1 $0) 1)))").unwrap(); - let out = dsl.eval(&expr, SimpleEvaluator::from(evaluate), &[ArithSpace::Num(1)]); + let out = dsl.eval( + &expr, + SimpleEvaluator::from(evaluate), + &[ArithSpace::Num(1)], + ); assert_eq!(out, Ok(ArithSpace::Bool(true))); - let out = dsl.eval(&expr, SimpleEvaluator::from(evaluate), &[ArithSpace::Num(0)]); + let out = dsl.eval( + &expr, + SimpleEvaluator::from(evaluate), + &[ArithSpace::Num(0)], + ); assert_eq!(out, Ok(ArithSpace::Bool(false))); }