From 005476cbcd71f4bcdfeda8f41461ea20dbdc09df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CWanglongzhi2001=E2=80=9D?= <“583087864@qq.com”> Date: Wed, 26 Jul 2023 15:31:06 +0800 Subject: [PATCH] fix: add the gradient of the tf.gradient opr --- src/TensorFlowNET.Core/Gradients/array_grad.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 1b6bc95ee..4b7027992 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -373,5 +373,13 @@ public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads) var p = op.inputs[1]; return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; } + + [RegisterGradient("ReverseV2")] + public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var axis = op.inputs[1]; + return new Tensor[] { array_ops.reverse(grad, axis), null }; + } } }