From b0dcf760e502aa3bfc01ae57126ccdf6784a01c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 8 Oct 2024 00:04:04 +0800 Subject: [PATCH] Fix tests --- lib/bumblebee/utils/nx.ex | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/utils/nx.ex b/lib/bumblebee/utils/nx.ex index f9806af6..c3bf368c 100644 --- a/lib/bumblebee/utils/nx.ex +++ b/lib/bumblebee/utils/nx.ex @@ -75,22 +75,22 @@ defmodule Bumblebee.Utils.Nx do iex> [first, second] = Bumblebee.Utils.Nx.batch_to_list(outputs) iex> first.x #Nx.Tensor< - s64[2] + s32[2] [0, 0] > iex> second.x #Nx.Tensor< - s64[2] + s32[2] [1, 1] > iex> first.y #Nx.Tensor< - s64 + s32 0 > iex> second.y #Nx.Tensor< - s64 + s32 1 > @@ -122,7 +122,7 @@ defmodule Bumblebee.Utils.Nx do iex> result = Bumblebee.Utils.Nx.composite_concatenate(left, right) iex> result.x #Nx.Tensor< - s64[4][2] + s32[4][2] [ [0, 0], [1, 1], @@ -132,7 +132,7 @@ defmodule Bumblebee.Utils.Nx do > iex> result.y #Nx.Tensor< - s64[4] + s32[4] [0, 1, 2, 3] > @@ -164,7 +164,7 @@ defmodule Bumblebee.Utils.Nx do iex> result = Bumblebee.Utils.Nx.composite_unflatten_batch(output, 2) iex> result.x #Nx.Tensor< - s64[2][1][2] + s32[2][1][2] [ [ [0, 0] @@ -176,7 +176,7 @@ defmodule Bumblebee.Utils.Nx do > iex> result.y #Nx.Tensor< - s64[2][1] + s32[2][1] [ [0], [1] @@ -205,12 +205,12 @@ defmodule Bumblebee.Utils.Nx do iex> result = Bumblebee.Utils.Nx.composite_flatten_batch(output) iex> result.x #Nx.Tensor< - s64[4] + s32[4] [0, 0, 1, 1] > iex> result.y #Nx.Tensor< - s64[2] + s32[2] [0, 1] > @@ -249,7 +249,7 @@ defmodule Bumblebee.Utils.Nx do iex> idx = Nx.tensor([[1, 0], [1, 1]]) iex> Bumblebee.Utils.Nx.batched_take(t, idx) #Nx.Tensor< - s64[2][2][2] + s32[2][2][2] [ [ [2, 2], @@ -348,7 +348,7 @@ defmodule Bumblebee.Utils.Nx do iex> x = Nx.tensor([[1, 2], [3, 4]]) iex> Bumblebee.Utils.Nx.repeat_interleave(x, 2) #Nx.Tensor< - s64[4][2] + s32[4][2] [ [1, 2], [1, 2], @@ -387,7 +387,7 @@ defmodule Bumblebee.Utils.Nx do iex> x = Nx.tensor([[1, 1], [2, 2], [3, 3], [4, 4]]) iex> Bumblebee.Utils.Nx.chunked_take(x, 2, Nx.tensor([1, 0])) #Nx.Tensor< - s64[2][2] + s32[2][2] [ [2, 2], [3, 3] @@ -427,7 +427,7 @@ defmodule Bumblebee.Utils.Nx do iex> x = Nx.iota({3, 3}) iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0]) #Nx.Tensor< - s64[3][3] + s32[3][3] [ [6, 7, 8], [0, 1, 2], @@ -438,7 +438,7 @@ defmodule Bumblebee.Utils.Nx do iex> x = Nx.iota({3, 3}) iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0]) #Nx.Tensor< - s64[3][3] + s32[3][3] [ [3, 4, 5], [6, 7, 8], @@ -449,7 +449,7 @@ defmodule Bumblebee.Utils.Nx do iex> x = Nx.iota({3, 3}) iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1]) #Nx.Tensor< - s64[3][3] + s32[3][3] [ [7, 8, 6], [1, 2, 0],