How can I use boolean indexing in Burn? #1585
zemelLeong asked this question in Q&A. Answered by laggui.
In PyTorch:

```python
>>> t
tensor([[[0.5213, 0.6049, 0.2158, 0.7163],
         [0.4655, 0.7438, 0.6514, 0.6525],
         [0.5567, 0.9781, 0.9310, 0.3846],
         [0.8512, 0.7049, 0.5219, 0.9497],
         [0.9796, 0.7220, 0.7281, 0.3046],
         [0.9927, 0.6197, 0.5130, 0.1818],
         [0.7108, 0.9334, 0.4279, 0.8117],
         [0.7960, 0.3307, 0.8622, 0.3465],
         [0.5505, 0.5056, 0.0849, 0.0585],
         [0.9278, 0.5415, 0.5889, 0.7620]]])
>>> inds = t[0, :, -1] > 0.5
>>> t[0, inds, :2]
tensor([[0.5213, 0.6049],
        [0.4655, 0.7438],
        [0.8512, 0.7049],
        [0.7108, 0.9334],
        [0.9278, 0.5415]])
```

This is how I imagine it could look in Burn:

```rust
let t = Tensor::<B, 3, Int>::random([1, 10, 4], Distribution::Default, &B::Device::default());
let mask = t.clone().slice([0..1, 0..10, 3..4]).greater_elem(50);
// A hypothetical method that could fetch the indices.
let indices = t.get_indices(1, mask);
t.select(1, indices);
```
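As a rough point of reference, the hypothetical `get_indices` above boils down to listing the positions where a flattened boolean mask is `true`. The plain-Rust sketch below does not use Burn at all, and the helper name `mask_to_indices` is made up for illustration. It applies that step to the last column of the example data, i.e. a host-side equivalent of `t[0, :, -1] > 0.5` followed by taking the indices.

```rust
/// Hypothetical host-side stand-in for the `get_indices` idea:
/// returns the positions where a flattened boolean mask is `true`.
fn mask_to_indices(mask: &[bool]) -> Vec<usize> {
    mask.iter()
        .enumerate()
        .filter_map(|(i, &keep)| if keep { Some(i) } else { None })
        .collect()
}

fn main() {
    // Last column (index 3) of the 10x4 example tensor from the question.
    let last_col = [
        0.7163, 0.6525, 0.3846, 0.9497, 0.3046, 0.1818, 0.8117, 0.3465, 0.0585, 0.7620,
    ];
    // Equivalent of `t[0, :, -1] > 0.5`.
    let mask: Vec<bool> = last_col.iter().map(|&v| v > 0.5).collect();
    let indices = mask_to_indices(&mask);
    println!("{:?}", indices); // [0, 1, 3, 6, 9]
}
```

These are exactly the five rows that appear in the PyTorch output.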
Answered by laggui on Apr 8, 2024. Replies: 2 comments
So far I have only come up with this rather clumsy workaround:

```rust
let t =
    Tensor::<B, 3, Int>::random([1, 10, 4], Distribution::Default, &B::Device::default());
// Boolean mask over the last column, flattened to 1-D (10 elements).
let mask = t
    .clone()
    .slice([0..1, 0..10, 3..4])
    .greater_elem(50)
    .reshape([-1]);
// Copy the mask to the host and collect the positions where it is `true`.
let indices = {
    let indices = mask
        .to_data()
        .value
        .iter()
        .enumerate()
        .filter_map(|(index, val)| if *val { Some(index as i32) } else { None })
        .collect::<Vec<i32>>();
    // Upload the index list back to the device as a 1-D Int tensor.
    Tensor::<B, 1, Int>::from_data(
        Data::new(indices.clone(), Shape::new([indices.len()])).convert(),
        &B::Device::default(),
    )
};
// Select the matching rows along dimension 1.
let res = t.select(1, indices);
println!("{:?}", res);
```
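Once the indices exist, `t.select(1, indices)` gathers whole rows along dimension 1, while the PyTorch example additionally keeps only the first two columns (`:2`). The std-only sketch below spells out both steps on a row-major buffer of shape `[1, 10, 4]` holding the example values. It illustrates the indexing semantics only; it is not Burn code, and `gather_rows` is a hypothetical helper.

```rust
/// Gathers `t[0, rows, ..keep_cols]` from a row-major buffer of shape
/// [1, n_rows, n_cols]: select along dim 1, then slice the last dim.
fn gather_rows(data: &[f32], n_cols: usize, rows: &[usize], keep_cols: usize) -> Vec<Vec<f32>> {
    assert!(keep_cols <= n_cols, "cannot keep more columns than exist");
    rows.iter()
        .map(|&r| {
            // Row r starts at offset r * n_cols (the batch dim has size 1).
            let start = r * n_cols;
            data[start..start + keep_cols].to_vec()
        })
        .collect()
}

fn main() {
    let n_cols = 4;
    let data: Vec<f32> = vec![
        0.5213, 0.6049, 0.2158, 0.7163,
        0.4655, 0.7438, 0.6514, 0.6525,
        0.5567, 0.9781, 0.9310, 0.3846,
        0.8512, 0.7049, 0.5219, 0.9497,
        0.9796, 0.7220, 0.7281, 0.3046,
        0.9927, 0.6197, 0.5130, 0.1818,
        0.7108, 0.9334, 0.4279, 0.8117,
        0.7960, 0.3307, 0.8622, 0.3465,
        0.5505, 0.5056, 0.0849, 0.0585,
        0.9278, 0.5415, 0.5889, 0.7620,
    ];
    // Mask on the last column, keeping the row indices where it is true.
    let rows: Vec<usize> = (0..data.len() / n_cols)
        .filter(|&r| data[r * n_cols + n_cols - 1] > 0.5)
        .collect();
    // Select those rows, keeping only the first two columns.
    let selected = gather_rows(&data, n_cols, &rows, 2);
    for row in &selected {
        println!("{:?}", row);
    }
}
```

Run as-is, it prints the same five rows as the PyTorch output in the question. In Burn, the column restriction would be a further `slice` over the last dimension of the selected tensor.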
On main we added `tensor.nonzero()` and `tensor.argwhere()`, so you could do the following for your original example:

```rust
let device = B::Device::default();
let t = Tensor::<B, 3>::from_floats(
    [[
        [0.5213, 0.6049, 0.2158, 0.7163],
        [0.4655, 0.7438, 0.6514, 0.6525],
        [0.5567, 0.9781, 0.9310, 0.3846],
        [0.8512, 0.7049, 0.5219, 0.9497],
        [0.9796, 0.7220, 0.7281, 0.3046],
        [0.9927, 0.6197, 0.5130, 0.1818],
        [0.7108, 0.9334, 0.4279, 0.8117],
        [0.7960, 0.3307, 0.8622, 0.3465],
        [0.5505, 0.5056, 0.0849, 0.0585],
        [0.9278, 0.5415, 0.5889, 0.7620],
    ]],
    &device,
);
// Boolean mask over the last column: shape [1, 10, 1].
let mask = t.clone().slice([0..1, 0..10, 3..4]).greater_elem(0.5);
// `nonzero()` returns one 1-D Int tensor of indices per dimension.
let indices = mask.nonzero();
// Keep only the indices along dimension 1 (the rows).
let indices_dim1 = indices.get(1).unwrap();
let selected = t.select(1, indices_dim1.clone());
```

This works well because you're only interested in selecting across one dimension. For more complex indexing it might get a bit more convoluted 😅
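Much of the extra work for more complex indexing comes from how the coordinates are laid out. Judging from the `.get(1)` call above, `nonzero()` hands back one index tensor per dimension; `argwhere()`, assuming it follows PyTorch's `torch.argwhere` convention, instead returns one row of coordinates per `true` entry, i.e. shape `[N, D]`. The std-only sketch below uses hypothetical helpers (not Burn's implementation) to compute both layouts for a row-major mask so the relationship is easy to see.

```rust
/// Coordinates of every `true` entry in a row-major mask:
/// one Vec of length D per hit (an argwhere-style [N, D] layout).
fn argwhere(mask: &[bool], shape: &[usize]) -> Vec<Vec<usize>> {
    assert_eq!(mask.len(), shape.iter().product::<usize>());
    mask.iter()
        .enumerate()
        .filter(|&(_, &m)| m)
        .map(|(flat, _)| {
            // Unravel the flat index, last dimension varying fastest.
            let mut rem = flat;
            let mut coord = vec![0; shape.len()];
            for d in (0..shape.len()).rev() {
                coord[d] = rem % shape[d];
                rem /= shape[d];
            }
            coord
        })
        .collect()
}

/// The same hits regrouped per dimension: D Vecs of length N (nonzero-style).
fn nonzero(mask: &[bool], shape: &[usize]) -> Vec<Vec<usize>> {
    let coords = argwhere(mask, shape);
    (0..shape.len())
        .map(|d| coords.iter().map(|c| c[d]).collect())
        .collect()
}

fn main() {
    // Mask of shape [1, 10, 1], like `slice([0..1, 0..10, 3..4]).greater_elem(0.5)`.
    let mask = [true, true, false, true, false, false, true, false, false, true];
    let shape = [1, 10, 1];
    println!("argwhere: {:?}", argwhere(&mask, &shape));
    // nonzero()[1] holds the row indices that select(1, ...) needs.
    println!("nonzero[1]: {:?}", nonzero(&mask, &shape)[1]);
}
```

For this mask, `nonzero(...)[1]` is `[0, 1, 3, 6, 9]`, the rows that `select(1, ...)` keeps, while each `argwhere` entry is a full `[0, row, 0]` coordinate. With a mask that is true in more than one dimension, both layouts carry several varying index columns, which a single `select` cannot consume directly.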
Answer selected by laggui.