Skip to content

Commit

Permalink
Fix inferred_value of KerasTensor. #1142
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Jul 17, 2023
1 parent 12e3f54 commit 6ec39ba
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/APIs/tf.reshape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ public Tensor reshape(Tensor tensor,
public Tensor reshape(Tensor tensor,
object[] shape,
string name = null)
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name);
=> array_ops.reshape(tensor, shape, name);
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/APIs/tf.tile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public Tensor tile(Tensor input, Tensor multiples, string name = null)
=> gen_array_ops.tile(input, multiples, name);

public Tensor tile(Tensor input, object[] multiples, string name = null)
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name);
=> array_ops.tile(input, multiples, name);

public Tensor tile(Tensor input, Shape multiples, string name = null)
{
Expand Down
3 changes: 2 additions & 1 deletion src/TensorFlowNET.Core/GlobalUsing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
global using System.Data;
global using System.Linq;
global using Tensorflow.Keras.Engine;
global using Tensorflow.Framework.Models;
global using Tensorflow.Framework.Models;
global using static Tensorflow.Binding;
19 changes: 15 additions & 4 deletions src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,32 @@ public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string nam
public static KerasTensor from_tensor(Tensor tensor)
{
var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
Shape? inferred_value = default;
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2)
{
inferred_value = tf.ones(tensor).shape;
}
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name);
kt.original_tensors = tensor;
return kt;
}

public KerasTensor this[int idx]
=> _original_tensors.First()[idx];

public KerasTensor this[params Slice[] slices]
=> _original_tensors.First()[slices];

public override string ToString()
=> _original_tensors.Length switch
{
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}",
_ => _original_tensors.ToString(),
};

private string GetInferredValueString()
=> _inferred_value == null ? "" : "";
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}";

public static implicit operator Tensors(KerasTensor kt)
=> kt._original_tensors;
Expand Down
29 changes: 27 additions & 2 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLO
if(shape.Length > 1)
{
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
if(shapeTensor.ndim > 1)
if (shapeTensor.ndim > 1)
{
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
}
Expand Down Expand Up @@ -304,6 +304,10 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
{
elems_as_tensors.Add(tensor);
}
else if (elem is KerasTensor kt)
{
elems_as_tensors.Add(kt);
}
else
{
var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString());
Expand Down Expand Up @@ -404,7 +408,10 @@ public static Tensor reshape(Tensor tensor, Shape shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);

public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name);
{
var dims = shape_utils.from_object_array(shape);
return gen_array_ops.reshape(tensor, dims, name: name);
}

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
Expand All @@ -425,6 +432,10 @@ public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT
return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
{
name = scope;
if (shape._shape_tuple().Length == 0)
{
shape = reshape(shape, new Shape(-1));
}
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
return output;
});
Expand Down Expand Up @@ -647,6 +658,20 @@ public static Tensor tile(Tensor input, Tensor multiples, string name = null)
}
});

public static Tensor tile(Tensor input, object[] multiples, string name = null)
{
Shape dims = shape_utils.from_object_array(multiples);

return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims)
{
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
}
});
}

public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>
Expand Down
27 changes: 27 additions & 0 deletions src/TensorFlowNET.Core/Tensors/shape_utils.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
Expand All @@ -13,5 +14,31 @@ public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor el

throw new NotImplementedException("");
}

public static Shape from_object_array(object[] shape)
{
var dims = shape.Select(x =>
{
if (x is KerasTensor kt && kt.inferred_value != null)
{
return kt.inferred_value.as_int_list()[0];
}
else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32)
{
return et.ToArray<int>()[0];
}
else if (x is int i)
{
return i;
}
else if (x is long l)
{
return l;
}
throw new NotImplementedException();
}).ToArray();

return new Shape(dims);
}
}
}
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Tensors/tf.constant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, stri
public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.ones(shape, dtype, name);

public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.ones(shape, dtype, name);

public Tensor size(Tensor input,
string name = null,
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input,
Expand Down
11 changes: 9 additions & 2 deletions src/TensorFlowNET.Core/ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,18 @@ public static Tensor convert_to_tensor(object value,
}
if (!graph.building_function)
{
throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
// return eager_tensor.AsPlaceholder(name: name);
// throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
return eager_tensor.AsPlaceholder(name: name);
}
}
}
else if (value is KerasTensor kt)
{
if (kt.inferred_value != null)
{
return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name);
}
}

// graph mode
Tensor ret = value switch
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac

<ItemGroup>
<PackageReference Include="HDF5-CSharp" Version="1.17.0" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
<PackageReference Include="SharpZipLib" Version="1.4.2" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
</ItemGroup>
Expand Down

0 comments on commit 6ec39ba

Please sign in to comment.