Skip to content

Commit

Permalink
Merge pull request #1208 from Wanglongzhi2001/fix_concat_v2_bug
Browse files Browse the repository at this point in the history
fix: fix the bug caused by concat_v2
  • Loading branch information
Oceania2018 authored Nov 5, 2023
2 parents 7f0bde3 + 8c06bbb commit 6aa2ab2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ public static Tensor _transpose_batch_time(Tensor x)
return x;

var x_rank = array_ops.rank(x);
var con1 = new Tensor[]
var con1 = new object[]
{
new Tensor(new int[]{0, 2}),
new []{1, 0 },
math_ops.range(2, x_rank)
};
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
/// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
}

public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
public static Tensor concat(object[] values, int axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/nn_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ private static Tensor _flatten_outer_dims(Tensor logits)
new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) });

var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var output = array_ops.reshape(logits, ops);

// Set output shape if known.
Expand Down

0 comments on commit 6aa2ab2

Please sign in to comment.