From 8f494ab40fb09af0f036835d7107b004c02c9e9c Mon Sep 17 00:00:00 2001 From: Devin Jean Date: Wed, 29 Nov 2023 11:10:17 -0600 Subject: [PATCH] unary op preserve tensor topology --- Cargo.toml | 2 +- src/process.rs | 85 ++++++++++---------- src/test/blocks/preserve-tensor-topology.xml | 1 + src/test/process.rs | 22 +++++ 4 files changed, 68 insertions(+), 42 deletions(-) create mode 100644 src/test/blocks/preserve-tensor-topology.xml diff --git a/Cargo.toml b/Cargo.toml index ab6a28d..2297be5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "netsblox-vm" -version = "0.3.3" +version = "0.3.4" edition = "2021" license = "MIT OR Apache-2.0" authors = ["Devin Jean "] diff --git a/src/process.rs b/src/process.rs index c6fb9b1..c0bc5cf 100644 --- a/src/process.rs +++ b/src/process.rs @@ -1407,6 +1407,12 @@ impl<'gc, C: CustomTypes, S: System> Process<'gc, C, S> { mod ops { use super::*; + #[derive(Clone, Copy, PartialEq, Eq)] + enum OpType { + Deterministic, + Nondeterministic, + } + fn as_list<'gc, C: CustomTypes, S: System>(v: &Value<'gc, C, S>) -> Option>>>> { v.as_list().ok() } @@ -1433,13 +1439,9 @@ mod ops { pub(super) fn prep_index_set<'gc, C: CustomTypes, S: System>(index: &Value<'gc, C, S>, len: usize) -> Result, ErrorCause> { fn set_impl<'gc, C: CustomTypes, S: System>(index: &Value<'gc, C, S>, len: usize, dest: &mut BTreeSet, cache: &mut BTreeSet>) -> Result<(), ErrorCause> { match index { - Value::List(values) => { - let key = index.identity(); - if cache.insert(key) { - for value in values.borrow().iter() { - set_impl(value, len, dest, cache)?; - } - cache.remove(&key); + Value::List(values) => if cache.insert(index.identity()) { + for value in values.borrow().iter() { + set_impl(value, len, dest, cache)?; } } _ => { @@ -1449,9 +1451,7 @@ mod ops { Ok(()) } let mut res = Default::default(); - let mut cache = Default::default(); - set_impl(index, len, &mut res, &mut cache)?; - debug_assert_eq!(cache.len(), 0); + set_impl(index, len, &mut res, &mut Default::default())?; Ok(res) } @@ -1752,7 +1752,7 @@ mod ops { } } - fn unary_op_impl<'gc, C: CustomTypes, S: System>(mc: &Mutation<'gc>, system: &S, x: &Value<'gc, C, S>, cache: &mut BTreeMap, Value<'gc, C, S>>, scalar_op: &dyn Fn(&Mutation<'gc>, &S, &Value<'gc, C, S>) -> Result, ErrorCause>) -> Result, ErrorCause> { + fn unary_op_impl<'gc, C: CustomTypes, S: System>(mc: &Mutation<'gc>, system: &S, x: &Value<'gc, C, S>, cache: &mut BTreeMap, Value<'gc, C, S>>, op_type: OpType, scalar_op: &dyn Fn(&Mutation<'gc>, &S, &Value<'gc, C, S>) -> Result, ErrorCause>) -> Result, ErrorCause> { let cache_key = x.identity(); Ok(match cache.get(&cache_key) { Some(x) => x.clone(), @@ -1764,9 +1764,12 @@ mod ops { let res = as_list(&real_res).unwrap(); let mut res = res.borrow_mut(mc); for x in &*x { - res.push_back(unary_op_impl(mc, system, x, cache, scalar_op)?); + res.push_back(unary_op_impl(mc, system, x, cache, op_type, scalar_op)?); + } + match op_type { + OpType::Deterministic => (), + OpType::Nondeterministic => { cache.remove(&cache_key); } } - cache.remove(&cache_key); real_res } None => scalar_op(mc, system, x)?, @@ -1776,52 +1779,52 @@ mod ops { pub(super) fn unary_op<'gc, C: CustomTypes, S: System>(mc: &Mutation<'gc>, system: &S, x: &Value<'gc, C, S>, op: UnaryOp) -> Result, ErrorCause> { let mut cache = Default::default(); match op { - UnaryOp::ToNumber => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.into())), - UnaryOp::Not => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok((!x.as_bool()?).into())), - UnaryOp::Abs => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.abs()?.into())), - UnaryOp::Neg => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.neg()?.into())), - UnaryOp::Sqrt => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.sqrt()?.into())), - UnaryOp::Round => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.round()?.into())), - UnaryOp::Floor => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.floor()?.into())), - UnaryOp::Ceil => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(x.as_number()?.ceil()?.into())), - UnaryOp::Sin => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::sin(x.as_number()?.get().to_radians()))?.into())), - UnaryOp::Cos => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::cos(x.as_number()?.get().to_radians()))?.into())), - UnaryOp::Tan => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::tan(x.as_number()?.get().to_radians()))?.into())), - UnaryOp::Asin => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::asin(x.as_number()?.get()).to_degrees())?.into())), - UnaryOp::Acos => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::acos(x.as_number()?.get()).to_degrees())?.into())), - UnaryOp::Atan => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(libm::atan(x.as_number()?.get()).to_degrees())?.into())), - UnaryOp::StrLen => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| Ok(Number::new(x.as_string()?.chars().count() as f64)?.into())), - - UnaryOp::StrGetLast => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| match x.as_string()?.chars().next_back() { + UnaryOp::ToNumber => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.into())), + UnaryOp::Not => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok((!x.as_bool()?).into())), + UnaryOp::Abs => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.abs()?.into())), + UnaryOp::Neg => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.neg()?.into())), + UnaryOp::Sqrt => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.sqrt()?.into())), + UnaryOp::Round => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.round()?.into())), + UnaryOp::Floor => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.floor()?.into())), + UnaryOp::Ceil => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(x.as_number()?.ceil()?.into())), + UnaryOp::Sin => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::sin(x.as_number()?.get().to_radians()))?.into())), + UnaryOp::Cos => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::cos(x.as_number()?.get().to_radians()))?.into())), + UnaryOp::Tan => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::tan(x.as_number()?.get().to_radians()))?.into())), + UnaryOp::Asin => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::asin(x.as_number()?.get()).to_degrees())?.into())), + UnaryOp::Acos => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::acos(x.as_number()?.get()).to_degrees())?.into())), + UnaryOp::Atan => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(libm::atan(x.as_number()?.get()).to_degrees())?.into())), + UnaryOp::StrLen => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| Ok(Number::new(x.as_string()?.chars().count() as f64)?.into())), + + UnaryOp::StrGetLast => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| match x.as_string()?.chars().next_back() { Some(ch) => Ok(Rc::new(ch.to_string()).into()), None => Err(ErrorCause::IndexOutOfBounds { index: 1, len: 0 }), }), - UnaryOp::StrGetRandom => unary_op_impl(mc, system, x, &mut cache, &|_, system, x| { + UnaryOp::StrGetRandom => unary_op_impl(mc, system, x, &mut cache, OpType::Nondeterministic, &|_, system, x| { let x = x.as_string()?; let i = prep_rand_index(system, x.chars().count())?; Ok(Rc::new(x.chars().nth(i).unwrap().to_string()).into()) }), - UnaryOp::SplitLetter => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitLetter => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { Ok(Gc::new(mc, RefLock::new(x.as_string()?.chars().map(|x| Rc::new(x.to_string()).into()).collect::>())).into()) }), - UnaryOp::SplitWord => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitWord => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { Ok(Gc::new(mc, RefLock::new(x.as_string()?.split_whitespace().map(|x| Rc::new(x.to_owned()).into()).collect::>())).into()) }), - UnaryOp::SplitTab => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitTab => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { Ok(Gc::new(mc, RefLock::new(x.as_string()?.split('\t').map(|x| Rc::new(x.to_owned()).into()).collect::>())).into()) }), - UnaryOp::SplitCR => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitCR => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { Ok(Gc::new(mc, RefLock::new(x.as_string()?.split('\r').map(|x| Rc::new(x.to_owned()).into()).collect::>())).into()) }), - UnaryOp::SplitLF => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitLF => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { Ok(Gc::new(mc, RefLock::new(x.as_string()?.lines().map(|x| Rc::new(x.to_owned()).into()).collect::>())).into()) }), - UnaryOp::SplitCsv => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitCsv => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { let value = from_csv(mc, x.as_string()?.as_ref())?; Ok(Gc::new(mc, RefLock::new(value)).into()) }), - UnaryOp::SplitJson => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::SplitJson => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { let value = x.as_string()?; match parse_json::(&value) { Ok(json) => Ok(Value::from_simple(mc, SimpleValue::from_json(json)?)), @@ -1829,7 +1832,7 @@ mod ops { } }), - UnaryOp::UnicodeToChar => unary_op_impl(mc, system, x, &mut cache, &|_, _, x| { + UnaryOp::UnicodeToChar => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|_, _, x| { let fnum = x.as_number()?.get(); if fnum < 0.0 || fnum > u32::MAX as f64 { return Err(ErrorCause::InvalidUnicode { value: fnum }) } let num = fnum as u32; @@ -1839,7 +1842,7 @@ mod ops { None => Err(ErrorCause::InvalidUnicode { value: fnum }), } }), - UnaryOp::CharToUnicode => unary_op_impl(mc, system, x, &mut cache, &|mc, _, x| { + UnaryOp::CharToUnicode => unary_op_impl(mc, system, x, &mut cache, OpType::Deterministic, &|mc, _, x| { let src = x.as_string()?; let values: VecDeque<_> = src.chars().map(|ch| Ok(Number::new(ch as u32 as f64)?.into())).collect::>()?; Ok(match values.len() { @@ -1852,7 +1855,7 @@ mod ops { pub(super) fn index_list<'gc, C: CustomTypes, S: System>(mc: &Mutation<'gc>, system: &S, list: &Value<'gc, C, S>, index: &Value<'gc, C, S>) -> Result, ErrorCause> { let list = list.as_list()?; let list = list.borrow(); - unary_op_impl(mc, system, index, &mut Default::default(), &|_, _, x| Ok(list[prep_index(x, list.len())?].clone())) + unary_op_impl(mc, system, index, &mut Default::default(), OpType::Deterministic, &|_, _, x| Ok(list[prep_index(x, list.len())?].clone())) } fn cmp_impl<'gc, C: CustomTypes, S: System>(a: &Value<'gc, C, S>, b: &Value<'gc, C, S>, cache: &mut BTreeMap<(Identity<'gc, C, S>, Identity<'gc, C, S>), Option>>) -> Result, ErrorCause> { diff --git a/src/test/blocks/preserve-tensor-topology.xml b/src/test/blocks/preserve-tensor-topology.xml new file mode 100644 index 0000000..403e5db --- /dev/null +++ b/src/test/blocks/preserve-tensor-topology.xml @@ -0,0 +1 @@ +
\ No newline at end of file diff --git a/src/test/process.rs b/src/test/process.rs index 0ece415..8915235 100644 --- a/src/test/process.rs +++ b/src/test/process.rs @@ -1150,6 +1150,28 @@ fn test_proc_variadic_sum_product() { }); } +#[test] +fn test_proc_preserve_tensor_topology() { + let system = Rc::new(StdSystem::new_sync(BASE_URL.to_owned(), None, Config::default(), Arc::new(Clock::new(UtcOffset::UTC, None)))); + let (mut env, _) = get_running_proc(&format!(include_str!("templates/generic-static.xml"), + globals = "", + fields = "", + funcs = include_str!("blocks/preserve-tensor-topology.xml"), + methods = "", + ), Settings::default(), system, |_| SymbolTable::default()); + + run_till_term(&mut env, |mc, _, res| { + let expect = Value::from_simple(mc, SimpleValue::from_json(json!([ + [[["5", "3", "4"], ["5", "3", "4"], ["5", "3", "4"]], true, false, false], + [[[6, 4, 5], [6, 4, 5], [6, 4, 5]], true, false, false], + [[[10, 6, 8], [10, 6, 8], [10, 6, 8]], true, false, false], + [[[-5, -3, -4], [-5, -3, -4], [-5, -3, -4]], true, false, false], + [[[32, 8, 16], [32, 8, 16], [32, 8, 16]], true, false, false], + ])).unwrap()); + assert_values_eq(&res.unwrap().0, &expect, 1e-5, "preserve tensor topology"); + }); +} + #[test] fn test_proc_variadic_min_max() { let system = Rc::new(StdSystem::new_sync(BASE_URL.to_owned(), None, Config::default(), Arc::new(Clock::new(UtcOffset::UTC, None))));