Skip to content

Commit

Permalink
other tensor index ops
Browse files Browse the repository at this point in the history
  • Loading branch information
dragazo committed Nov 27, 2023
1 parent e8b3e00 commit bb97ec9
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 6 deletions.
46 changes: 40 additions & 6 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,11 @@ impl<'gc, C: CustomTypes<S>, S: System<C>> Process<'gc, C, S> {
let val = self.value_stack.pop().unwrap();
let mut list = list.borrow_mut(mc);

let index = ops::prep_index(&index, list.len() + 1)?;
list.insert(index, val);
let index_set = ops::prep_index_set(&index, list.len() + 1)?;
for index in index_set.into_iter().rev() {
list.insert(index, val.clone());
}

self.pos = aft_pos;
}
Instruction::ListInsertLast => {
Expand Down Expand Up @@ -702,8 +705,11 @@ impl<'gc, C: CustomTypes<S>, S: System<C>> Process<'gc, C, S> {
let index = self.value_stack.pop().unwrap();
let mut list = list.borrow_mut(mc);

let index = ops::prep_index(&index, list.len())?;
list[index] = value;
let index_set = ops::prep_index_set(&index, list.len())?;
for index in index_set {
list[index] = value.clone();
}

self.pos = aft_pos;
}
Instruction::ListAssignLast => {
Expand All @@ -728,8 +734,12 @@ impl<'gc, C: CustomTypes<S>, S: System<C>> Process<'gc, C, S> {
let list = self.value_stack.pop().unwrap().as_list()?;
let index = self.value_stack.pop().unwrap();
let mut list = list.borrow_mut(mc);
let index = ops::prep_index(&index, list.len())?;
list.remove(index);

let index_set = ops::prep_index_set(&index, list.len())?;
for index in index_set.into_iter().rev() {
list.remove(index);
}

self.pos = aft_pos;
}
Instruction::ListRemoveLast => {
Expand Down Expand Up @@ -1420,6 +1430,30 @@ mod ops {
if len == 0 { return Err(ErrorCause::IndexOutOfBounds { index: 1, len: 0 }) }
Ok(system.rand(0..len))
}
pub(super) fn prep_index_set<'gc, C: CustomTypes<S>, S: System<C>>(index: &Value<'gc, C, S>, len: usize) -> Result<BTreeSet<usize>, ErrorCause<C, S>> {
fn set_impl<'gc, C: CustomTypes<S>, S: System<C>>(index: &Value<'gc, C, S>, len: usize, dest: &mut BTreeSet<usize>, cache: &mut BTreeSet<Identity<'gc, C, S>>) -> Result<(), ErrorCause<C, S>> {
match index {
Value::List(values) => {
let key = index.identity();
if cache.insert(key) {
for value in values.borrow().iter() {
set_impl(value, len, dest, cache)?;
}
cache.remove(&key);
}
}
_ => {
dest.insert(ops::prep_index(index, len)?);
}
}
Ok(())
}
let mut res = Default::default();
let mut cache = Default::default();
set_impl(index, len, &mut res, &mut cache)?;
debug_assert_eq!(cache.len(), 0);
Ok(res)
}

pub(super) fn flatten<'gc, C: CustomTypes<S>, S: System<C>>(value: &Value<'gc, C, S>) -> Result<VecDeque<Value<'gc, C, S>>, ErrorCause<C, S>> {
fn flatten_impl<'gc, C: CustomTypes<S>, S: System<C>>(value: &Value<'gc, C, S>, dest: &mut VecDeque<Value<'gc, C, S>>, cache: &mut BTreeSet<Identity<'gc, C, S>>) -> Result<(), ErrorCause<C, S>> {
Expand Down
1 change: 1 addition & 0 deletions src/test/blocks/tensor-list-idx.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<blocks app="NetsBlox 2.2.0, http://netsblox.org" version="2.2.0"><block-definition collabId="item_-1_2" s="main" type="reporter" category="custom"><header></header><code></code><translations></translations><inputs></inputs><script><block collabId="item_0" s="doDeclareVariables"><list><l>data</l><l>idx</l><l>res</l></list></block><block collabId="item_2" s="doSetVar"><l>data</l><block collabId="item_6" s="reportNewList"><list><l>11</l><block collabId="item_13" s="reportVariadicSum"><list><l>2</l><l>0</l></list></block><l>help</l><block collabId="item_24" s="reportNewList"><list><l>4</l><l>3</l><l>5</l></list></block><l>tr</l><l>34</l><block collabId="item_18" s="reportMonadic"><l><option>neg</option></l><l>7</l></block></list></block></block><block collabId="item_57" s="doSetVar"><l>idx</l><block collabId="item_59" s="reportNewList"><list><l>3</l><l>1</l><block collabId="item_63" s="reportNewList"><list><l>2</l><l>1</l><l>3</l></list></block><l>7</l></list></block></block><block collabId="item_76" s="doAddToList"><block collabId="item_78" var="idx"/><block collabId="item_77" var="idx"/></block><block collabId="item_51" s="doSetVar"><l>res</l><block collabId="item_53" s="reportNewList"><list></list></block></block><block collabId="item_90" s="doAddToList"><block collabId="item_92" s="reportNewList"><block collabId="item_93" var="data"/></block><block collabId="item_90_2" var="res"/></block><block collabId="item_79" s="doAddToList"><block collabId="item_86" s="reportListItem"><block collabId="item_87" var="idx"/><block collabId="item_88" var="data"/></block><block collabId="item_80" var="res"/></block><block collabId="item_95" s="doAddToList"><block collabId="item_95_1" s="reportNewList"><block collabId="item_95_3" var="data"/></block><block collabId="item_95_2" var="res"/></block><block collabId="item_97" s="doReplaceInList"><block collabId="item_98" var="idx"/><block collabId="item_117" var="data"/><block collabId="item_100" s="reportNewList"><list><l>1</l><l>2</l></list></block></block><block collabId="item_103" s="doAddToList"><block collabId="item_103_1" s="reportNewList"><block collabId="item_103_3" var="data"/></block><block collabId="item_103_2" var="res"/></block><block collabId="item_104" s="doInsertInList"><block collabId="item_105" s="reportNewList"><list><l>7</l><l>5</l></list></block><block collabId="item_108" var="idx"/><block collabId="item_109" var="data"/></block><block collabId="item_110" s="doAddToList"><block collabId="item_110_1" s="reportNewList"><block collabId="item_110_3" var="data"/></block><block collabId="item_110_2" var="res"/></block><block collabId="item_111" s="doDeleteFromList"><block collabId="item_112" var="idx"/><block collabId="item_113" var="data"/></block><block collabId="item_114" s="doAddToList"><block collabId="item_114_1" s="reportNewList"><block collabId="item_114_3" var="data"/></block><block collabId="item_114_2" var="res"/></block><block collabId="item_54" s="doReport"><block collabId="item_56" var="res"/></block></script></block-definition></blocks>
62 changes: 62 additions & 0 deletions src/test/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,68 @@ fn test_proc_pick_random() {
});
}

#[test]
fn test_proc_tensor_list_idx() {
let system = Rc::new(StdSystem::new_sync(BASE_URL.to_owned(), None, Config::default(), Arc::new(Clock::new(UtcOffset::UTC, None))));
let (mut env, _) = get_running_proc(&format!(include_str!("templates/generic-static.xml"),
globals = "",
fields = "",
funcs = include_str!("blocks/tensor-list-idx.xml"),
methods = "",
), Settings::default(), system, |_| SymbolTable::default());

run_till_term(&mut env, |mc, _, res| {
let results = res.unwrap().0.as_list().unwrap();
let results = &*results.borrow();

assert_eq!(results.len(), 6);
assert_values_eq(&results[0], &Value::from_simple(mc, SimpleValue::from_json(json!(["11", 2, "help", ["4", "3", "5"], "tr", "34", -7])).unwrap()), 1e-5, "tensor list idx 0");
match &results[1] {
Value::List(x) => {
let x = &*x.borrow();
assert_eq!(x.len(), 5);
match &x[0] {
Value::String(x) => assert_eq!(x.as_str(), "help"),
x => panic!("{x:?}"),
}
match &x[1] {
Value::String(x) => assert_eq!(x.as_str(), "11"),
x => panic!("{x:?}"),
}
match &x[2] {
Value::List(x) => {
let x = &*x.borrow();
assert_eq!(x.len(), 3);
match &x[0] {
Value::Number(x) => assert!(x.get() == 2.0),
x => panic!("{x:?}"),
}
match &x[1] {
Value::String(x) => assert_eq!(x.as_str(), "11"),
x => panic!("{x:?}"),
}
match &x[2] {
Value::String(x) => assert_eq!(x.as_str(), "help"),
x => panic!("{x:?}"),
}
}
x => panic!("{x:?}"),
}
match &x[3] {
Value::Number(x) => assert!(x.get() == -7.0),
x => panic!("{x:?}"),
}
assert_eq!(results[1].identity(), x[4].identity());
}
x => panic!("{x:?}"),
}
assert_values_eq(&results[2], &Value::from_simple(mc, SimpleValue::from_json(json!(["11", 2, "help", ["4", "3", "5"], "tr", "34", -7])).unwrap()), 1e-5, "tensor list idx 2");
assert_values_eq(&results[3], &Value::from_simple(mc, SimpleValue::from_json(json!([["1", "2"], ["1", "2"], ["1", "2"], ["4", "3", "5"], "tr", "34", ["1", "2"]])).unwrap()), 1e-5, "tensor list idx 3");
assert_values_eq(&results[4], &Value::from_simple(mc, SimpleValue::from_json(json!([["7", "5"], ["1", "2"], ["7", "5"], ["1", "2"], ["7", "5"], ["1", "2"], ["4", "3", "5"], "tr", "34", ["7", "5"], ["1", "2"]])).unwrap()), 1e-5, "tensor list idx 4");
assert_values_eq(&results[5], &Value::from_simple(mc, SimpleValue::from_json(json!([["1", "2"], ["7", "5"], ["1", "2"], "tr", "34", ["7", "5"], ["1", "2"]])).unwrap()), 1e-5, "tensor list idx 5");
});
}

#[test]
fn test_proc_rand_list_ops() {
let system = Rc::new(StdSystem::new_sync(BASE_URL.to_owned(), None, Config::default(), Arc::new(Clock::new(UtcOffset::UTC, None))));
Expand Down

0 comments on commit bb97ec9

Please sign in to comment.