-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1250 from SchoenTannenbaum/master
fix: regularizer serialization problem
- Loading branch information
Showing
10 changed files
with
289 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,25 @@ | ||
namespace Tensorflow.Keras | ||
using Newtonsoft.Json; | ||
using System.Collections.Generic; | ||
using Tensorflow.Keras.Saving.Common; | ||
|
||
namespace Tensorflow.Keras | ||
{ | ||
public interface IRegularizer | ||
{ | ||
Tensor Apply(RegularizerArgs args); | ||
} | ||
[JsonConverter(typeof(CustomizedRegularizerJsonConverter))] | ||
public interface IRegularizer | ||
{ | ||
[JsonProperty("class_name")] | ||
string ClassName { get; } | ||
[JsonProperty("config")] | ||
IDictionary<string, object> Config { get; } | ||
Tensor Apply(RegularizerArgs args); | ||
} | ||
|
||
public interface IRegularizerApi | ||
{ | ||
IRegularizer GetRegularizerFromName(string name); | ||
IRegularizer L1 { get; } | ||
IRegularizer L2 { get; } | ||
IRegularizer L1L2 { get; } | ||
} | ||
|
||
} |
57 changes: 57 additions & 0 deletions
57
src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedRegularizerJsonConverter.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
using Newtonsoft.Json.Linq; | ||
using Newtonsoft.Json; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
using Tensorflow.Operations.Regularizers; | ||
|
||
namespace Tensorflow.Keras.Saving.Common | ||
{ | ||
class RegularizerInfo | ||
{ | ||
public string class_name { get; set; } | ||
public JObject config { get; set; } | ||
} | ||
|
||
public class CustomizedRegularizerJsonConverter : JsonConverter | ||
{ | ||
public override bool CanConvert(Type objectType) | ||
{ | ||
return objectType == typeof(IRegularizer); | ||
} | ||
|
||
public override bool CanRead => true; | ||
|
||
public override bool CanWrite => true; | ||
|
||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | ||
{ | ||
var regularizer = value as IRegularizer; | ||
if (regularizer is null) | ||
{ | ||
JToken.FromObject(null).WriteTo(writer); | ||
return; | ||
} | ||
JToken.FromObject(new RegularizerInfo() | ||
{ | ||
class_name = regularizer.ClassName, | ||
config = JObject.FromObject(regularizer.Config) | ||
}, serializer).WriteTo(writer); | ||
} | ||
|
||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||
{ | ||
var info = serializer.Deserialize<RegularizerInfo>(reader); | ||
if (info is null) | ||
{ | ||
return null; | ||
} | ||
return info.class_name switch | ||
{ | ||
"L1L2" => new L1L2 (info.config["l1"].ToObject<float>(), info.config["l2"].ToObject<float>()), | ||
"L1" => new L1(info.config["l1"].ToObject<float>()), | ||
"L2" => new L2(info.config["l2"].ToObject<float>()), | ||
}; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using System; | ||
|
||
using Tensorflow.Keras; | ||
|
||
namespace Tensorflow.Operations.Regularizers | ||
{ | ||
public class L1 : IRegularizer | ||
{ | ||
float _l1; | ||
private readonly Dictionary<string, object> _config; | ||
|
||
public string ClassName => "L1"; | ||
public virtual IDictionary<string, object> Config => _config; | ||
|
||
public L1(float l1 = 0.01f) | ||
{ | ||
// l1 = 0.01 if l1 is None else l1 | ||
// validate_float_arg(l1, name = "l1") | ||
// self.l1 = ops.convert_to_tensor(l1) | ||
this._l1 = l1; | ||
|
||
_config = new(); | ||
_config["l1"] = _l1; | ||
} | ||
|
||
|
||
public Tensor Apply(RegularizerArgs args) | ||
{ | ||
//return self.l1 * ops.sum(ops.absolute(x)) | ||
return _l1 * math_ops.reduce_sum(math_ops.abs(args.X)); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
using System; | ||
|
||
using Tensorflow.Keras; | ||
|
||
namespace Tensorflow.Operations.Regularizers | ||
{ | ||
public class L1L2 : IRegularizer | ||
{ | ||
float _l1; | ||
float _l2; | ||
private readonly Dictionary<string, object> _config; | ||
|
||
public string ClassName => "L1L2"; | ||
public virtual IDictionary<string, object> Config => _config; | ||
|
||
public L1L2(float l1 = 0.0f, float l2 = 0.0f) | ||
{ | ||
//l1 = 0.0 if l1 is None else l1 | ||
//l2 = 0.0 if l2 is None else l2 | ||
// validate_float_arg(l1, name = "l1") | ||
// validate_float_arg(l2, name = "l2") | ||
|
||
// self.l1 = l1 | ||
// self.l2 = l2 | ||
this._l1 = l1; | ||
this._l2 = l2; | ||
|
||
_config = new(); | ||
_config["l1"] = l1; | ||
_config["l2"] = l2; | ||
} | ||
|
||
public Tensor Apply(RegularizerArgs args) | ||
{ | ||
//regularization = ops.convert_to_tensor(0.0, dtype = x.dtype) | ||
//if self.l1: | ||
// regularization += self.l1 * ops.sum(ops.absolute(x)) | ||
//if self.l2: | ||
// regularization += self.l2 * ops.sum(ops.square(x)) | ||
//return regularization | ||
|
||
Tensor regularization = tf.constant(0.0, args.X.dtype); | ||
regularization += _l1 * math_ops.reduce_sum(math_ops.abs(args.X)); | ||
regularization += _l2 * math_ops.reduce_sum(math_ops.square(args.X)); | ||
return regularization; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using System; | ||
|
||
using Tensorflow.Keras; | ||
|
||
namespace Tensorflow.Operations.Regularizers | ||
{ | ||
public class L2 : IRegularizer | ||
{ | ||
float _l2; | ||
private readonly Dictionary<string, object> _config; | ||
|
||
public string ClassName => "L2"; | ||
public virtual IDictionary<string, object> Config => _config; | ||
|
||
public L2(float l2 = 0.01f) | ||
{ | ||
// l2 = 0.01 if l2 is None else l2 | ||
// validate_float_arg(l2, name = "l2") | ||
// self.l2 = l2 | ||
this._l2 = l2; | ||
|
||
_config = new(); | ||
_config["l2"] = _l2; | ||
} | ||
|
||
|
||
public Tensor Apply(RegularizerArgs args) | ||
{ | ||
//return self.l2 * ops.sum(ops.square(x)) | ||
return _l2 * math_ops.reduce_sum(math_ops.square(args.X)); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,51 @@ | ||
namespace Tensorflow.Keras | ||
using Tensorflow.Operations.Regularizers; | ||
|
||
namespace Tensorflow.Keras | ||
{ | ||
public class Regularizers | ||
public class Regularizers: IRegularizerApi | ||
{ | ||
private static Dictionary<string, IRegularizer> _nameActivationMap; | ||
|
||
public IRegularizer l1(float l1 = 0.01f) | ||
=> new L1(l1); | ||
public IRegularizer l2(float l2 = 0.01f) | ||
=> new L2(l2); | ||
|
||
//From TF source | ||
//# The default value for l1 and l2 are different from the value in l1_l2 | ||
//# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2 | ||
//# and no l1 penalty. | ||
public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f) | ||
=> new L1L2(l1, l2); | ||
|
||
static Regularizers() | ||
{ | ||
public IRegularizer l2(float l2 = 0.01f) | ||
=> new L2(l2); | ||
_nameActivationMap = new Dictionary<string, IRegularizer>(); | ||
_nameActivationMap["L1"] = new L1(); | ||
_nameActivationMap["L1"] = new L2(); | ||
_nameActivationMap["L1"] = new L1L2(); | ||
} | ||
|
||
public IRegularizer L1 => l1(); | ||
|
||
public IRegularizer L2 => l2(); | ||
|
||
public IRegularizer L1L2 => l1l2(); | ||
|
||
public IRegularizer GetRegularizerFromName(string name) | ||
{ | ||
if (name == null) | ||
{ | ||
throw new Exception($"Regularizer name cannot be null"); | ||
} | ||
if (!_nameActivationMap.TryGetValue(name, out var res)) | ||
{ | ||
throw new Exception($"Regularizer {name} not found"); | ||
} | ||
else | ||
{ | ||
return res; | ||
} | ||
} | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters