Skip to content

Commit

Permalink
Merge pull request #97 from andfoy/add_stack
Browse files Browse the repository at this point in the history
Add stack
  • Loading branch information
andfoy authored Dec 2, 2023
2 parents 53d2f0b + cb9d355 commit 358e725
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 1 deletion.
61 changes: 60 additions & 1 deletion lib/extorch/native/tensor/ops/manipulation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,66 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
iex> b.size
{1, 3, 4, 5}
"""
@spec squeeze(ExTorch.Tensor.t(), integer() | tuple() | [integer()] | nil) :: ExTorch.Tensor.t()
@spec squeeze(ExTorch.Tensor.t(), integer() | tuple() | [integer()] | nil) ::
ExTorch.Tensor.t()
defbinding(squeeze(input, dim \\ nil))

@doc """
Concatenates a sequence of tensors along a new dimension.

All tensors need to be of the same size. This function is similar to `ExTorch.cat/3`,
except that the tensors are joined along a newly inserted dimension rather than an
existing one.

## Arguments
- `tensors` (`[ExTorch.Tensor] | tuple()`) - A sequence of tensors of the same type. Non-empty
tensors provided must have the same shape.

## Optional arguments
- `dim` (`integer()`) - the dimension over which the tensors are concatenated. Must be
between 0 and the number of dimensions of the input tensors (inclusive). Default: 0
- `out` (`ExTorch.Tensor | nil`) - an optional pre-allocated tensor used to store the
concatenation output. Default: nil

## Examples
iex> a = ExTorch.rand({3, 4})
#Tensor<
[[0.7419, 0.4063, 0.0514, 0.4281],
[0.7350, 0.1977, 0.5593, 0.1701],
[0.4135, 0.7213, 0.9591, 0.2798]]
[size: {3, 4}, dtype: :float, device: :cpu, requires_grad: false]>

# Stack tensors into a new dimension inserted at position 0 (the default)
iex> ExTorch.stack([a, a, a])
#Tensor<
[[[0.7419, 0.4063, 0.0514, 0.4281],
[0.7350, 0.1977, 0.5593, 0.1701],
[0.4135, 0.7213, 0.9591, 0.2798]],
[[0.7419, 0.4063, 0.0514, 0.4281],
[0.7350, 0.1977, 0.5593, 0.1701],
[0.4135, 0.7213, 0.9591, 0.2798]],
[[0.7419, 0.4063, 0.0514, 0.4281],
[0.7350, 0.1977, 0.5593, 0.1701],
[0.4135, 0.7213, 0.9591, 0.2798]]]
[size: {3, 3, 4}, dtype: :float, device: :cpu, requires_grad: false]>

# Stack tensors into a new dimension inserted at position 1
iex> ExTorch.stack([a, a, a], 1)
#Tensor<
[[[0.7419, 0.4063, 0.0514, 0.4281],
[0.7419, 0.4063, 0.0514, 0.4281],
[0.7419, 0.4063, 0.0514, 0.4281]],
[[0.7350, 0.1977, 0.5593, 0.1701],
[0.7350, 0.1977, 0.5593, 0.1701],
[0.7350, 0.1977, 0.5593, 0.1701]],
[[0.4135, 0.7213, 0.9591, 0.2798],
[0.4135, 0.7213, 0.9591, 0.2798],
[0.4135, 0.7213, 0.9591, 0.2798]]]
[size: {3, 3, 4}, dtype: :float, device: :cpu, requires_grad: false]>
"""
@spec stack([ExTorch.Tensor.t()] | tuple(), integer(), ExTorch.Tensor.t() | nil) ::
ExTorch.Tensor.t()
defbinding(stack(input, dim \\ 0, out \\ nil))
end
end
2 changes: 2 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,5 @@ TensorList split(
std::shared_ptr<CrossTensor> squeeze(
const std::shared_ptr<CrossTensor> &input,
rust::Vec<int64_t> dims);

/// Concatenate a sequence of tensors along a new dimension `dim`.
/// When `opt_out.used` is set, the result is written into the tensor
/// carried by `opt_out` instead of a freshly allocated one.
std::shared_ptr<CrossTensor> stack(TensorList seq, int64_t dim, TensorOut opt_out);
13 changes: 13 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,16 @@ std::shared_ptr<CrossTensor> squeeze(
}
return std::make_shared<CrossTensor>(std::move(out_tensor));
}

// Concatenate `seq` along a new dimension `dim`, optionally writing the
// result into the caller-supplied output tensor carried by `opt_out`.
std::shared_ptr<CrossTensor> stack(TensorList seq, int64_t dim, TensorOut opt_out) {
    // Convert the FFI tensor list into concrete tensors.
    std::vector<CrossTensor> tensors = unpack_tensor_list(seq);

    CrossTensor result;
    if (opt_out.used) {
        // Reuse the pre-allocated output tensor provided by the caller.
        result = *opt_out.tensor.get();
        result = torch::stack_out(result, tensors, dim);
    } else {
        result = torch::stack(tensors, dim);
    }
    return std::make_shared<CrossTensor>(std::move(result));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ rustler::init!(
scatter_reduce,
split,
squeeze,
stack,

// Tensor comparing operations
allclose,
Expand Down
3 changes: 3 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,6 @@ fn split(

/// Remove specified or all singleton dimensions from a input tensor.
fn squeeze(input: &SharedPtr<CrossTensor>, dims: Vec<i64>) -> Result<SharedPtr<CrossTensor>>;

/// Concatenate a sequence of tensors across a new dimension.
fn stack(seq: TensorList, dim: i64, out: TensorOut) -> Result<SharedPtr<CrossTensor>>;
2 changes: 2 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,5 @@ nif_impl!(
input: TensorStruct<'a>,
dims: Size
);

// NIF entry point for `stack`: forwards the tensor sequence, the target
// dimension and the optional pre-allocated output tensor to the native
// implementation.
nif_impl!(stack, TensorStruct<'a>, seq: TensorList, dim: i64, out: TensorOut);
37 changes: 37 additions & 0 deletions test/tensor/manipulation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -803,4 +803,41 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
out = ExTorch.squeeze(input, {0, 2})
assert out.size == expected
end

test "stack/1" do
  tensor = ExTorch.rand({3, 4, 2})

  # Stacking with the default dim is equivalent to unsqueezing at 0 and
  # concatenating along that new leading dimension.
  unsqueezed = ExTorch.unsqueeze(tensor, 0)
  expected = ExTorch.cat(List.duplicate(unsqueezed, 3))

  result = ExTorch.stack(List.duplicate(tensor, 3))
  assert ExTorch.allclose(result, expected)
end

test "stack/2" do
  tensor = ExTorch.rand({3, 4, 2})

  # Stacking at dim 2 matches unsqueeze-at-2 followed by cat along dim 2.
  unsqueezed = ExTorch.unsqueeze(tensor, 2)
  expected = ExTorch.cat(List.duplicate(unsqueezed, 3), 2)

  result = ExTorch.stack(List.duplicate(tensor, 3), 2)
  assert ExTorch.allclose(result, expected)
end

test "stack/2 with kwargs" do
  tensor = ExTorch.rand({3, 4, 2})

  # The dimension can also be passed as a keyword argument.
  unsqueezed = ExTorch.unsqueeze(tensor, 3)
  expected = ExTorch.cat(List.duplicate(unsqueezed, 3), dim: 3)

  result = ExTorch.stack(List.duplicate(tensor, 3), dim: 3)
  assert ExTorch.allclose(result, expected)
end

test "stack/3" do
  tensor = ExTorch.rand({3, 4, 2})

  unsqueezed = ExTorch.unsqueeze(tensor, 1)
  expected = ExTorch.cat(List.duplicate(unsqueezed, 3), 1)

  # The pre-allocated `out` tensor is filled in place by the native call,
  # so the return value does not need to be captured.
  out = ExTorch.empty_like(expected)
  ExTorch.stack(List.duplicate(tensor, 3), 1, out)
  assert ExTorch.allclose(out, expected)
end
end

0 comments on commit 358e725

Please sign in to comment.