-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3811e4e
commit 2121079
Showing
4 changed files
with
250 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
using Microsoft.VisualStudio.TestTools.UnitTesting; | ||
using System; | ||
using System.Linq; | ||
using System.Runtime.Intrinsics.X86; | ||
using System.Security.AccessControl; | ||
using Tensorflow.NumPy; | ||
using TensorFlowNET.UnitTest; | ||
using static Tensorflow.Binding; | ||
|
||
namespace Tensorflow.Keras.UnitTest.Optimizers | ||
{ | ||
[TestClass] | ||
public class GradientDescentOptimizerTest : PythonTest | ||
{ | ||
private void TestBasicGeneric<T>() where T : struct | ||
{ | ||
var dtype = Type.GetTypeCode(typeof(T)) switch | ||
{ | ||
TypeCode.Single => np.float32, | ||
TypeCode.Double => np.float64, | ||
_ => throw new NotImplementedException(), | ||
}; | ||
|
||
// train.GradientDescentOptimizer is V1 only API. | ||
tf.Graph().as_default(); | ||
using (self.cached_session()) | ||
{ | ||
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); | ||
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); | ||
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); | ||
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); | ||
var optimizer = tf.train.GradientDescentOptimizer(3.0f); | ||
var grads_and_vars = new[] { | ||
Tuple.Create(grads0, var0 as IVariableV1), | ||
Tuple.Create(grads1, var1 as IVariableV1) | ||
}; | ||
var sgd_op = optimizer.apply_gradients(grads_and_vars); | ||
|
||
var global_variables = variables.global_variables_initializer(); | ||
self.evaluate<T>(global_variables); | ||
// Fetch params to validate initial values | ||
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]> | ||
self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0)); | ||
self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1)); | ||
// Run 1 step of sgd | ||
sgd_op.run(); | ||
// Validate updated params | ||
self.assertAllCloseAccordingToType( | ||
new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, | ||
self.evaluate<double[]>(var0)); | ||
self.assertAllCloseAccordingToType( | ||
new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, | ||
self.evaluate<double[]>(var1)); | ||
// TODO: self.assertEqual(0, len(optimizer.variables())); | ||
} | ||
} | ||
|
||
[TestMethod] | ||
public void TestBasic() | ||
{ | ||
//TODO: add np.half | ||
TestBasicGeneric<float>(); | ||
TestBasicGeneric<double>(); | ||
} | ||
|
||
|
||
} | ||
} |