Skip to content

Commit

Permalink
Add array type
Browse files Browse the repository at this point in the history
  • Loading branch information
ssgier committed Jun 21, 2023
1 parent 325ddf0 commit 5e90e1f
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 25 deletions.
117 changes: 117 additions & 0 deletions src/crossover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ where
path_node_ctx,
rng,
),
spec::Node::Array { value_type, size } => self.crossover_array(
value_type,
*size,
individuals_ordered,
crossover_params,
path_node_ctx,
rng,
),
spec::Node::AnonMap {
value_type,
min_size,
Expand Down Expand Up @@ -286,6 +294,52 @@ where
selected_keys
}

fn crossover_array(
&self,
value_type: &spec::Node,
size: usize,
individuals_ordered: &[&value::Node],
crossover_params: &CrossoverParams,
path_node_ctx: &mut PathNodeContext,
rng: &mut StdRng,
) -> value::Node {
let inner_vectors = individuals_ordered
.iter()
.map(|individual| {
extract_inner_vector_from_array(individual)
.iter()
.map(AsRef::as_ref)
.collect_vec()
})
.collect_vec();

let result_elements = (0..size)
.map(|idx| {
let child_values = inner_vectors
.iter()
.map(|inner_vector| inner_vector[idx])
.collect_vec();

let child_path_node_ctx = path_node_ctx.get_or_create_child_mut(&idx.to_string());

let child_crossover_params = child_path_node_ctx
.rescaling_ctx
.current_rescaling
.rescale_crossover(crossover_params);

Box::new(self.do_crossover(
value_type,
&child_values,
&child_crossover_params,
child_path_node_ctx,
rng,
))
})
.collect();

value::Node::Array(result_elements)
}

fn crossover_anon_map(
&self,
value_type: &spec::Node,
Expand Down Expand Up @@ -498,6 +552,14 @@ pub struct Crossover<S: Selection = SelectionImpl> {
selection: S,
}

fn extract_inner_vector_from_array(value: &value::Node) -> &[Box<value::Node>] {
if let value::Node::Array(elements) = value {
elements
} else {
unreachable!()
}
}

fn extract_anon_map_inner(value: &value::Node) -> &HashMap<usize, Box<value::Node>> {
if let value::Node::AnonMap(mapping) = value {
mapping
Expand All @@ -512,6 +574,7 @@ mod tests {
use crate::path::testutil::set_rescaling_at_path;
use crate::rescaling::{CrossoverRescaling, MutationRescaling, Rescaling};
use crate::spec_util;
use crate::testutil::extract_as_bool;
use crate::testutil::extract_from_value;
use crate::types::HashSet;
use rand::SeedableRng;
Expand Down Expand Up @@ -1203,6 +1266,60 @@ mod tests {
}
}

#[test]
fn array_crossover() {
let spec_str = "
type: array
size: 2
valueType:
type: bool
init: true
";

let value0 = Value(value::Node::Array(vec![
Box::new(value::Node::Bool(false)),
Box::new(value::Node::Bool(false)),
]));

let value1 = Value(value::Node::Array(vec![
Box::new(value::Node::Bool(true)),
Box::new(value::Node::Bool(true)),
]));

let spec = spec_util::from_yaml_str(spec_str).unwrap();

let mut root_path_node_ctx = PathNodeContext::default();
root_path_node_ctx.add_nodes_for(&value0.0);
root_path_node_ctx.add_nodes_for(&value1.0);

set_rescaling_at_path(
&mut root_path_node_ctx,
&["0"],
make_rescaling(1.0, SELECT_1),
);

set_rescaling_at_path(
&mut root_path_node_ctx,
&["1"],
make_rescaling(1.0, SELECT_0),
);
let sut = make_crossover_with_pressure_aware_selection();

let result = sut.crossover(
&spec,
&[&value0, &value1],
&ALWAYS_CROSSOVER_PARAMS,
&mut PathContext(root_path_node_ctx),
&mut make_rng(),
);

let value0 = extract_as_bool(&result, &["0"]).unwrap();
let value1 = extract_as_bool(&result, &["1"]).unwrap();

assert!(value0);
assert!(!value1);
}

#[test]
fn anon_map_crossover() {
let spec_str = "
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub enum Error {
InvalidBounds { path_hint: String },
#[error("at path {path_hint:?}: min size must be lower than max size")]
InvalidSizeBounds { path_hint: String },
#[error("at path {path_hint:?}: array size must be strictly greater than 1")]
ArraySize { path_hint: String },
#[error("at path {path_hint:?}: max size must not be zero")]
ZeroMaxSize { path_hint: String },
#[error("at path {path_hint:?}: mandatory attribute missing: {}", .missing_attribute_name)]
Expand Down
72 changes: 72 additions & 0 deletions src/mutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ fn do_mutate(
(spec::Node::Sub { map, .. }, value::Node::Sub(value_map)) => {
mutate_sub(map, value_map, mutation_params, path_node_ctx, rng)
}
(spec::Node::Array { value_type, .. }, value::Node::Array(elements)) => {
mutate_array(value_type, elements, mutation_params, path_node_ctx, rng)
}
(
spec::Node::AnonMap {
value_type,
Expand Down Expand Up @@ -151,6 +154,36 @@ lazy_static! {
static ref COIN_FLIP: Bernoulli = Bernoulli::new(0.5).unwrap();
}

fn mutate_array(
value_type: &spec::Node,
elements: &[Box<value::Node>],
mutation_params: &MutationParams,
path_node_ctx: &mut PathNodeContext,
rng: &mut StdRng,
) -> value::Node {
let result_elements = elements
.iter()
.enumerate()
.map(|(idx, element)| {
let child_path_node_ctx = path_node_ctx.get_or_create_child_mut(&idx.to_string());
let child_mutation_params = child_path_node_ctx
.rescaling_ctx
.current_rescaling
.rescale_mutation(mutation_params);

Box::new(do_mutate(
element,
value_type,
&child_mutation_params,
child_path_node_ctx,
rng,
))
})
.collect();

value::Node::Array(result_elements)
}

fn mutate_anon_map(
value_type: &spec::Node,
min_size: &Option<usize>,
Expand Down Expand Up @@ -450,6 +483,7 @@ mod tests {
use crate::rescaling::Rescaling;
use crate::spec_util;
use crate::testutil::extract_as_anon_map;
use crate::testutil::extract_as_array;
use crate::testutil::extract_as_bool;
use crate::testutil::extract_as_int;
use crate::types::HashSet;
Expand Down Expand Up @@ -1054,6 +1088,44 @@ mod tests {
assert!(!extract_as_bool(&result, &["bool_a"]).unwrap());
}

#[test]
fn mutate_array() {
let spec_str = "
type: array
valueType:
type: bool
init: false
size: 2
";

let value_str = r#"
[false, false]
"#;

let spec = spec_util::from_yaml_str(spec_str).unwrap();
let value = value_util::from_json_str(value_str, &spec).unwrap();

let mutation_params = MutationParams {
mutation_prob: 1.0,
mutation_scale: 10.0,
};

let mut rng = rng();
let mut path_ctx = PathContext::default();
path_ctx.0.add_nodes_for(&value.0);

let rescaling = never_mutate_rescaling();
set_rescaling_at_path(&mut path_ctx.0, &["1"], rescaling);
let result = mutate(&spec, &value, &mutation_params, &mut path_ctx, &mut rng);

let result_elements = extract_as_array(&result, &[]).unwrap();

assert_eq!(result_elements.len(), 2);

assert_eq!(result_elements[0], value::Node::Bool(true));
assert_eq!(result_elements[1], value::Node::Bool(false));
}

#[test]
fn mutate_anon_map() {
let spec_str = "
Expand Down
24 changes: 22 additions & 2 deletions src/path.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::rescaling::RescalingContext;
use crate::types::HashMap;
use crate::value;
use crate::value::Node::*;
use crate::types::HashMap;

#[derive(Default)]
pub struct PathContext(pub PathNodeContext);
Expand Down Expand Up @@ -45,6 +45,12 @@ impl PathNodeContext {
child_node.add_nodes_for(value);
}
}
Array(elemengs) => {
for (idx, value) in elemengs.iter().enumerate() {
let child_node = self.child_nodes.entry(idx.to_string()).or_default();
child_node.add_nodes_for(value);
}
}
AnonMap(mapping) => {
for (key, value) in mapping {
self.key_mgr.on_key_seen(*key);
Expand Down Expand Up @@ -115,6 +121,13 @@ mod tests {
"d".to_string(),
Box::new(value::Node::Enum("foo".to_string())),
),
(
"arr".to_string(),
Box::new(value::Node::Array(vec![
Box::new(value::Node::Bool(true)),
Box::new(value::Node::Bool(false)),
])),
),
(
"foo".to_string(),
Box::new(value::Node::AnonMap(HashMap::from_iter([(
Expand All @@ -140,12 +153,19 @@ mod tests {
let mut sut = PathNodeContext::default();
sut.add_nodes_for(&value.0);

assert_eq!(sut.child_nodes.len(), 7);
assert_eq!(sut.child_nodes.len(), 8);
for key in ["a", "b", "c", "d"] {
let node = sut.get_child(key);
assert!(node.child_nodes.is_empty());
}

let arr = sut.get_child("arr");
assert_eq!(arr.child_nodes.len(), 2);
let arr_child_0 = arr.get_child("0");
let arr_child_1 = arr.get_child("1");
assert!(arr_child_0.child_nodes.is_empty());
assert!(arr_child_1.child_nodes.is_empty());

let foo = sut.get_child("foo");
assert_eq!(foo.child_nodes.len(), 1);
let foo_child = foo.get_child("4");
Expand Down
27 changes: 27 additions & 0 deletions src/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::types::HashMap;
use crate::value;
use crate::value::Value;
use serde::{Deserialize, Serialize};
use std::iter;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Spec(pub Node);
Expand All @@ -26,6 +27,10 @@ pub enum Node {
Sub {
map: HashMap<String, Box<Node>>,
},
Array {
value_type: Box<Node>,
size: usize,
},
AnonMap {
value_type: Box<Node>,
init_size: usize,
Expand Down Expand Up @@ -59,6 +64,15 @@ impl Node {
Node::Real { init, .. } => value::Node::Real(*init),
Node::Int { init, .. } => value::Node::Int(*init),
Node::Bool { init } => value::Node::Bool(*init),
Node::Array { value_type, size } => {
let init_val = value_type.initial_value();

value::Node::Array(
iter::repeat_with(|| Box::new(init_val.clone()))
.take(*size)
.collect(),
)
}
Node::AnonMap {
value_type,
init_size,
Expand Down Expand Up @@ -142,6 +156,12 @@ mod tests {
bar:
type: bool
init: false
e:
type: array
size: 2
valueType:
type: bool
init: true
";

let expected_init_val = Value(Node::Sub(HashMap::from_iter([
Expand All @@ -158,6 +178,13 @@ mod tests {
"d".to_string(),
Box::new(Node::Variant("foo".to_string(), Box::new(Node::Const))),
),
(
"e".to_string(),
Box::new(Node::Array(vec![
Box::new(Node::Bool(true)),
Box::new(Node::Bool(true)),
])),
),
])));

let init_val = from_yaml_str(spec_str).unwrap().initial_value();
Expand Down
Loading

0 comments on commit 5e90e1f

Please sign in to comment.