From 5e90e1f0fbc18dfd2a7e960e330329634892cc6d Mon Sep 17 00:00:00 2001 From: Sandro Sgier Date: Wed, 21 Jun 2023 20:46:32 +0200 Subject: [PATCH] Add array type --- src/crossover.rs | 117 ++++++++++++++++++++++++++++++++++++++++++++++ src/error.rs | 2 + src/mutation.rs | 72 ++++++++++++++++++++++++++++ src/path.rs | 24 +++++++++- src/spec.rs | 27 +++++++++++ src/spec_util.rs | 111 ++++++++++++++++++++++++++++++++++++++----- src/testutil.rs | 31 ++++++++---- src/value.rs | 15 +++++- src/value_util.rs | 51 ++++++++++++++++++++ tests/checksum.rs | 9 +++- 10 files changed, 434 insertions(+), 25 deletions(-) diff --git a/src/crossover.rs b/src/crossover.rs index 0ed1313..0c9667a 100644 --- a/src/crossover.rs +++ b/src/crossover.rs @@ -145,6 +145,14 @@ where path_node_ctx, rng, ), + spec::Node::Array { value_type, size } => self.crossover_array( + value_type, + *size, + individuals_ordered, + crossover_params, + path_node_ctx, + rng, + ), spec::Node::AnonMap { value_type, min_size, @@ -286,6 +294,52 @@ where selected_keys } + fn crossover_array( + &self, + value_type: &spec::Node, + size: usize, + individuals_ordered: &[&value::Node], + crossover_params: &CrossoverParams, + path_node_ctx: &mut PathNodeContext, + rng: &mut StdRng, + ) -> value::Node { + let inner_vectors = individuals_ordered + .iter() + .map(|individual| { + extract_inner_vector_from_array(individual) + .iter() + .map(AsRef::as_ref) + .collect_vec() + }) + .collect_vec(); + + let result_elements = (0..size) + .map(|idx| { + let child_values = inner_vectors + .iter() + .map(|inner_vector| inner_vector[idx]) + .collect_vec(); + + let child_path_node_ctx = path_node_ctx.get_or_create_child_mut(&idx.to_string()); + + let child_crossover_params = child_path_node_ctx + .rescaling_ctx + .current_rescaling + .rescale_crossover(crossover_params); + + Box::new(self.do_crossover( + value_type, + &child_values, + &child_crossover_params, + child_path_node_ctx, + rng, + )) + }) + .collect(); + + value::Node::Array(result_elements) + } + fn crossover_anon_map( &self, value_type: &spec::Node, @@ -498,6 +552,14 @@ pub struct Crossover { selection: S, } +fn extract_inner_vector_from_array(value: &value::Node) -> &[Box] { + if let value::Node::Array(elements) = value { + elements + } else { + unreachable!() + } +} + fn extract_anon_map_inner(value: &value::Node) -> &HashMap> { if let value::Node::AnonMap(mapping) = value { mapping @@ -512,6 +574,7 @@ mod tests { use crate::path::testutil::set_rescaling_at_path; use crate::rescaling::{CrossoverRescaling, MutationRescaling, Rescaling}; use crate::spec_util; + use crate::testutil::extract_as_bool; use crate::testutil::extract_from_value; use crate::types::HashSet; use rand::SeedableRng; @@ -1203,6 +1266,60 @@ mod tests { } } + #[test] + fn array_crossover() { + let spec_str = " + type: array + size: 2 + valueType: + type: bool + init: true + "; + + let value0 = Value(value::Node::Array(vec![ + Box::new(value::Node::Bool(false)), + Box::new(value::Node::Bool(false)), + ])); + + let value1 = Value(value::Node::Array(vec![ + Box::new(value::Node::Bool(true)), + Box::new(value::Node::Bool(true)), + ])); + + let spec = spec_util::from_yaml_str(spec_str).unwrap(); + + let mut root_path_node_ctx = PathNodeContext::default(); + root_path_node_ctx.add_nodes_for(&value0.0); + root_path_node_ctx.add_nodes_for(&value1.0); + + set_rescaling_at_path( + &mut root_path_node_ctx, + &["0"], + make_rescaling(1.0, SELECT_1), + ); + + set_rescaling_at_path( + &mut root_path_node_ctx, + &["1"], + make_rescaling(1.0, SELECT_0), + ); + let sut = make_crossover_with_pressure_aware_selection(); + + let result = sut.crossover( + &spec, + &[&value0, &value1], + &ALWAYS_CROSSOVER_PARAMS, + &mut PathContext(root_path_node_ctx), + &mut make_rng(), + ); + + let value0 = extract_as_bool(&result, &["0"]).unwrap(); + let value1 = extract_as_bool(&result, &["1"]).unwrap(); + + assert!(value0); + assert!(!value1); + } + #[test] fn anon_map_crossover() { let spec_str = " diff --git a/src/error.rs b/src/error.rs index 02ab44d..e3c3850 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,6 +44,8 @@ pub enum Error { InvalidBounds { path_hint: String }, #[error("at path {path_hint:?}: min size must be lower than max size")] InvalidSizeBounds { path_hint: String }, + #[error("at path {path_hint:?}: array size must be strictly greater than 1")] + ArraySize { path_hint: String }, #[error("at path {path_hint:?}: max size must not be zero")] ZeroMaxSize { path_hint: String }, #[error("at path {path_hint:?}: mandatory attribute missing: {}", .missing_attribute_name)] diff --git a/src/mutation.rs b/src/mutation.rs index c2ee433..7e239b8 100644 --- a/src/mutation.rs +++ b/src/mutation.rs @@ -68,6 +68,9 @@ fn do_mutate( (spec::Node::Sub { map, .. }, value::Node::Sub(value_map)) => { mutate_sub(map, value_map, mutation_params, path_node_ctx, rng) } + (spec::Node::Array { value_type, .. }, value::Node::Array(elements)) => { + mutate_array(value_type, elements, mutation_params, path_node_ctx, rng) + } ( spec::Node::AnonMap { value_type, @@ -151,6 +154,36 @@ lazy_static! { static ref COIN_FLIP: Bernoulli = Bernoulli::new(0.5).unwrap(); } +fn mutate_array( + value_type: &spec::Node, + elements: &[Box], + mutation_params: &MutationParams, + path_node_ctx: &mut PathNodeContext, + rng: &mut StdRng, +) -> value::Node { + let result_elements = elements + .iter() + .enumerate() + .map(|(idx, element)| { + let child_path_node_ctx = path_node_ctx.get_or_create_child_mut(&idx.to_string()); + let child_mutation_params = child_path_node_ctx + .rescaling_ctx + .current_rescaling + .rescale_mutation(mutation_params); + + Box::new(do_mutate( + element, + value_type, + &child_mutation_params, + child_path_node_ctx, + rng, + )) + }) + .collect(); + + value::Node::Array(result_elements) +} + fn mutate_anon_map( value_type: &spec::Node, min_size: &Option, @@ -450,6 +483,7 @@ mod tests { use crate::rescaling::Rescaling; use crate::spec_util; use crate::testutil::extract_as_anon_map; + use crate::testutil::extract_as_array; use crate::testutil::extract_as_bool; use crate::testutil::extract_as_int; use crate::types::HashSet; @@ -1054,6 +1088,44 @@ mod tests { assert!(!extract_as_bool(&result, &["bool_a"]).unwrap()); } + #[test] + fn mutate_array() { + let spec_str = " + type: array + valueType: + type: bool + init: false + size: 2 + "; + + let value_str = r#" + [false, false] + "#; + + let spec = spec_util::from_yaml_str(spec_str).unwrap(); + let value = value_util::from_json_str(value_str, &spec).unwrap(); + + let mutation_params = MutationParams { + mutation_prob: 1.0, + mutation_scale: 10.0, + }; + + let mut rng = rng(); + let mut path_ctx = PathContext::default(); + path_ctx.0.add_nodes_for(&value.0); + + let rescaling = never_mutate_rescaling(); + set_rescaling_at_path(&mut path_ctx.0, &["1"], rescaling); + let result = mutate(&spec, &value, &mutation_params, &mut path_ctx, &mut rng); + + let result_elements = extract_as_array(&result, &[]).unwrap(); + + assert_eq!(result_elements.len(), 2); + + assert_eq!(result_elements[0], value::Node::Bool(true)); + assert_eq!(result_elements[1], value::Node::Bool(false)); + } + #[test] fn mutate_anon_map() { let spec_str = " diff --git a/src/path.rs b/src/path.rs index 38de088..77b7e67 100644 --- a/src/path.rs +++ b/src/path.rs @@ -1,7 +1,7 @@ use crate::rescaling::RescalingContext; +use crate::types::HashMap; use crate::value; use crate::value::Node::*; -use crate::types::HashMap; #[derive(Default)] pub struct PathContext(pub PathNodeContext); @@ -45,6 +45,12 @@ impl PathNodeContext { child_node.add_nodes_for(value); } } + Array(elemengs) => { + for (idx, value) in elemengs.iter().enumerate() { + let child_node = self.child_nodes.entry(idx.to_string()).or_default(); + child_node.add_nodes_for(value); + } + } AnonMap(mapping) => { for (key, value) in mapping { self.key_mgr.on_key_seen(*key); @@ -115,6 +121,13 @@ mod tests { "d".to_string(), Box::new(value::Node::Enum("foo".to_string())), ), + ( + "arr".to_string(), + Box::new(value::Node::Array(vec![ + Box::new(value::Node::Bool(true)), + Box::new(value::Node::Bool(false)), + ])), + ), ( "foo".to_string(), Box::new(value::Node::AnonMap(HashMap::from_iter([( @@ -140,12 +153,19 @@ mod tests { let mut sut = PathNodeContext::default(); sut.add_nodes_for(&value.0); - assert_eq!(sut.child_nodes.len(), 7); + assert_eq!(sut.child_nodes.len(), 8); for key in ["a", "b", "c", "d"] { let node = sut.get_child(key); assert!(node.child_nodes.is_empty()); } + let arr = sut.get_child("arr"); + assert_eq!(arr.child_nodes.len(), 2); + let arr_child_0 = arr.get_child("0"); + let arr_child_1 = arr.get_child("1"); + assert!(arr_child_0.child_nodes.is_empty()); + assert!(arr_child_1.child_nodes.is_empty()); + let foo = sut.get_child("foo"); assert_eq!(foo.child_nodes.len(), 1); let foo_child = foo.get_child("4"); diff --git a/src/spec.rs b/src/spec.rs index e14dceb..93f66da 100644 --- a/src/spec.rs +++ b/src/spec.rs @@ -2,6 +2,7 @@ use crate::types::HashMap; use crate::value; use crate::value::Value; use serde::{Deserialize, Serialize}; +use std::iter; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Spec(pub Node); @@ -26,6 +27,10 @@ pub enum Node { Sub { map: HashMap>, }, + Array { + value_type: Box, + size: usize, + }, AnonMap { value_type: Box, init_size: usize, @@ -59,6 +64,15 @@ impl Node { Node::Real { init, .. } => value::Node::Real(*init), Node::Int { init, .. } => value::Node::Int(*init), Node::Bool { init } => value::Node::Bool(*init), + Node::Array { value_type, size } => { + let init_val = value_type.initial_value(); + + value::Node::Array( + iter::repeat_with(|| Box::new(init_val.clone())) + .take(*size) + .collect(), + ) + } Node::AnonMap { value_type, init_size, @@ -142,6 +156,12 @@ mod tests { bar: type: bool init: false + e: + type: array + size: 2 + valueType: + type: bool + init: true "; let expected_init_val = Value(Node::Sub(HashMap::from_iter([ @@ -158,6 +178,13 @@ mod tests { "d".to_string(), Box::new(Node::Variant("foo".to_string(), Box::new(Node::Const))), ), + ( + "e".to_string(), + Box::new(Node::Array(vec![ + Box::new(Node::Bool(true)), + Box::new(Node::Bool(true)), + ])), + ), ]))); let init_val = from_yaml_str(spec_str).unwrap().initial_value(); diff --git a/src/spec_util.rs b/src/spec_util.rs index bd83339..06233a4 100644 --- a/src/spec_util.rs +++ b/src/spec_util.rs @@ -17,9 +17,11 @@ pub fn from_yaml_str(yaml_str: &str) -> Result { pub fn is_leaf(spec_node: &Node) -> bool { match spec_node { - Node::Sub { .. } | Node::AnonMap { .. } | Node::Variant { .. } | Node::Optional { .. } => { - false - } + Node::Sub { .. } + | Node::Array { .. } + | Node::AnonMap { .. } + | Node::Variant { .. } + | Node::Optional { .. } => false, Node::Bool { .. } | Node::Real { .. } | Node::Int { .. } @@ -29,7 +31,7 @@ pub fn is_leaf(spec_node: &Node) -> bool { } const BUILT_IN_TYPE_NAMES: &[&str] = &[ - "real", "int", "bool", "sub", "anon map", "variant", "enum", "optional", "const", + "real", "int", "bool", "sub", "array", "anon map", "variant", "enum", "optional", "const", ]; fn build_node( @@ -54,6 +56,7 @@ fn build_node( "int" => build_int(mapping, path), "bool" => build_bool(mapping, path), "sub" => build_sub(mapping, type_defs, path), + "array" => build_array(mapping, type_defs, path), "anon map" => build_anon_map(mapping, type_defs, path), "variant" => build_variant(mapping, type_defs, path), "enum" => build_enum(mapping, path), @@ -273,15 +276,33 @@ fn extract_value_type_attr_value( type_defs: &HashMap, path: &[&str], ) -> Result, Error> { - return match mapping.get("valueType") { + match mapping.get("valueType") { Some(value) => Ok(Box::new(build_node(value, type_defs, path)?)), - None => { - return Err(Error::MandatoryAttributeMissing { - path_hint: format_path(path), - missing_attribute_name: "valueType".to_string(), - }); - } - }; + None => Err(Error::MandatoryAttributeMissing { + path_hint: format_path(path), + missing_attribute_name: "valueType".to_string(), + }), + } +} + +fn build_array( + mapping: &serde_yaml::Mapping, + type_defs: &HashMap, + path: &[&str], +) -> Result { + check_for_unexpected_attributes(mapping, ["type", "size", "valueType"], path)?; + + let value_type = extract_value_type_attr_value(mapping, type_defs, path)?; + + let size = extract_usize_attribute_value(mapping, "size", path, true)?.unwrap(); + + if size < 2 { + Err(Error::ArraySize { + path_hint: format_path(path), + }) + } else { + Ok(Node::Array { value_type, size }) + } } fn build_anon_map( @@ -949,6 +970,72 @@ mod tests { )); } + #[test] + fn array() { + let yaml_str = " + type: array + valueType: + type: bool + init: false + size: 2 + "; + + assert!(matches!( + from_yaml_str(yaml_str), + Ok(Spec(Node::Array { + value_type, + size: 2, + })) if *value_type.as_ref() == Node::Bool {init: false} + )); + } + + #[test] + fn array_zero_size() { + let yaml_str = " + type: array + valueType: + type: bool + init: false + size: 1 + "; + + assert!(matches!( + from_yaml_str(yaml_str), + Err(Error::ArraySize { path_hint }) + if path_hint == "(root)" + )); + } + + #[test] + fn array_missing_value_type() { + let yaml_str = " + type: array + size: 4 + "; + + assert!(matches!( + from_yaml_str(yaml_str), + Err(Error::MandatoryAttributeMissing { path_hint, missing_attribute_name }) + if path_hint == "(root)" && missing_attribute_name == "valueType" + )); + } + + #[test] + fn array_missing_size() { + let yaml_str = " + type: array + valueType: + type: bool + init: false + "; + + assert!(matches!( + from_yaml_str(yaml_str), + Err(Error::MandatoryAttributeMissing { path_hint, missing_attribute_name }) + if path_hint == "(root)" && missing_attribute_name == "size" + )); + } + #[test] fn anon_map() { let yaml_str = " diff --git a/src/testutil.rs b/src/testutil.rs index 5b76baa..2d5ef13 100644 --- a/src/testutil.rs +++ b/src/testutil.rs @@ -2,7 +2,7 @@ use crate::types::HashMap; use crate::value::Node::{self, *}; use crate::value::Value; -pub fn extract_as_real<'a>(value: &'a Value, path: &[&str]) -> Option { +pub fn extract_as_real(value: &Value, path: &[&str]) -> Option { extract_from_node(Some(&value.0), path).map(|node| { if let Node::Real(value) = node { *value @@ -12,7 +12,7 @@ pub fn extract_as_real<'a>(value: &'a Value, path: &[&str]) -> Option { }) } -pub fn extract_as_int<'a>(value: &'a Value, path: &[&str]) -> Option { +pub fn extract_as_int(value: &Value, path: &[&str]) -> Option { extract_from_node(Some(&value.0), path).map(|node| { if let Node::Int(value) = node { *value @@ -22,7 +22,7 @@ pub fn extract_as_int<'a>(value: &'a Value, path: &[&str]) -> Option { }) } -pub fn extract_as_bool<'a>(value: &'a Value, path: &[&str]) -> Option { +pub fn extract_as_bool(value: &Value, path: &[&str]) -> Option { extract_from_node(Some(&value.0), path).map(|node| { if let Node::Bool(value) = node { *value @@ -32,15 +32,22 @@ pub fn extract_as_bool<'a>(value: &'a Value, path: &[&str]) -> Option { }) } -pub fn extract_as_anon_map<'a>( - value: &'a Value, - path: &[&str], -) -> Option>> { +pub fn extract_as_array(value: &Value, path: &[&str]) -> Option> { + extract_from_node(Some(&value.0), path).map(|node| { + if let Node::Array(elements) = node { + elements.iter().map(|elem| *elem.clone()).collect() + } else { + panic!("Node is not of type array") + } + }) +} + +pub fn extract_as_anon_map(value: &Value, path: &[&str]) -> Option>> { extract_from_node(Some(&value.0), path).map(|node| { if let Node::AnonMap(mapping) = node { mapping.clone() } else { - panic!("Node is not of type bool") + panic!("Node is not of type anon map") } }) } @@ -53,9 +60,15 @@ pub fn extract_from_node<'a>(node: Option<&'a Node>, path: &[&str]) -> Option<&' node.and_then(|node| match path.first() { Some(head) => match node { Sub(mapping) => extract_from_node(mapping.get(*head).map(Box::as_ref), &path[1..]), + Array(elements) => extract_from_node( + elements + .get(str::parse::(head).expect("Invalid path")) + .map(Box::as_ref), + &path[1..], + ), AnonMap(mapping) => extract_from_node( mapping - .get(&str::parse(*head).expect("Invalid path")) + .get(&str::parse(head).expect("Invalid path")) .map(Box::as_ref), &path[1..], ), diff --git a/src/value.rs b/src/value.rs index 848b271..99600a3 100644 --- a/src/value.rs +++ b/src/value.rs @@ -13,6 +13,7 @@ pub enum Node { Int(i64), Bool(bool), Sub(HashMap>), + Array(Vec>), AnonMap(HashMap>), Variant(String, Box), Enum(String), @@ -32,6 +33,7 @@ impl Node { Node::Real(number) => serde_json::Value::Number(Number::from_f64(*number).unwrap()), Node::Int(number) => serde_json::Value::Number(Number::from(*number)), Node::Bool(val) => serde_json::Value::Bool(*val), + Node::Array(elements) => Self::map_to_json_array(elements), Node::AnonMap(mapping) => Self::map_to_json_obj(mapping), Node::Sub(mapping) => Self::map_to_json_obj(mapping), Node::Variant(variant_name, value) => { @@ -59,6 +61,10 @@ impl Node { } serde_json::Value::Object(out_mapping) } + + fn map_to_json_array(elements: &[Box]) -> serde_json::Value { + serde_json::Value::Array(elements.iter().map(|elem| elem.to_json()).collect()) + } } #[cfg(test)] @@ -79,9 +85,16 @@ mod tests { Box::new(Node::Bool(true)), )]))), ), + ( + "d".to_string(), + Box::new(Node::Array(vec![ + Box::new(Node::Int(1)), + Box::new(Node::Int(3)), + ])), + ), ]))); - let expect_json_text = r#"{"a":2.0,"b":1,"c":{"0":true}}"#; + let expect_json_text = r#"{"a":2.0,"b":1,"c":{"0":true},"d":[1,3]}"#; assert_eq!(value.to_json().to_string(), expect_json_text); } } diff --git a/src/value_util.rs b/src/value_util.rs index 2009510..d5bbaf5 100644 --- a/src/value_util.rs +++ b/src/value_util.rs @@ -25,6 +25,7 @@ fn build_node( spec::Node::Int { min, max, .. } => build_int(json_val, min, max, path), spec::Node::Bool { .. } => build_bool(json_val, path), spec::Node::Sub { map: ref spec_map } => build_sub(json_val, spec_map, path), + spec::Node::Array { ref value_type, .. } => build_array(json_val, value_type, path), spec::Node::AnonMap { ref value_type, .. } => build_anon_map(json_val, value_type, path), spec::Node::Variant { map: ref spec_map, .. @@ -161,6 +162,31 @@ fn build_sub( } } +fn build_array( + json_val: &serde_json::Value, + spec_node: &spec::Node, + path: &[&str], +) -> Result { + match json_val { + serde_json::Value::Array(json_values) => { + let mut elements = Vec::with_capacity(json_values.len()); + + for (idx, json_value) in json_values.iter().enumerate() { + let path_item = idx.to_string(); + let path_of_sub = [path, &[&path_item]].concat(); + let value = build_node(json_value, spec_node, &path_of_sub)?; + elements.push(Box::new(value)); + } + + Ok(Node::Array(elements)) + } + _ => Err(Error::WrongTypeForValue { + path_hint: format_path(path), + type_hint: "array".to_string(), + }), + } +} + fn build_anon_map( json_val: &serde_json::Value, spec_node: &spec::Node, @@ -867,6 +893,31 @@ mod tests { if path_hint.as_str() == "foo")); } + #[test] + fn array() { + let spec_str = " + type: array + size: 2 + valueType: + type: bool + init: false + "; + let value_str = " + [true, true] + "; + + let spec = spec_util::from_yaml_str(spec_str).unwrap(); + let result = from_json_str(value_str, &spec); + + assert_eq!( + result.unwrap(), + Value(Node::Array(vec![ + Box::new(Node::Bool(true)), + Box::new(Node::Bool(true)) + ])) + ); + } + #[test] fn anon_map_no_keys() { let spec_str = " diff --git a/tests/checksum.rs b/tests/checksum.rs index a88c84d..dcb6b9d 100644 --- a/tests/checksum.rs +++ b/tests/checksum.rs @@ -5,7 +5,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; // not constant across targets -const CHECKSUM: u64 = 15057406331161146692; +const CHECKSUM: u64 = 16590185745624415069; fn compute_hash(value: &serde_json::Value) -> u64 { let mut hasher = DefaultHasher::new(); @@ -55,6 +55,13 @@ fn checksum() { init: false baz: type: const + arr: + type: array + size: 2 + valueType: + type: int + init: 2 + scale: 4 "; let spec = spec_util::from_yaml_str(spec_str).unwrap();