From 4efa0a8881a5886a2eeb2e19d8a5157c3f68a32f Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Sat, 1 Jul 2023 13:44:54 +0800 Subject: [PATCH] add pad preprocessing for `imdb` dataset --- src/TensorFlowNET.Keras/Datasets/Imdb.cs | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs index 56b0d2a77..61ce39475 100644 --- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -40,6 +40,8 @@ public DatasetPass load_data(string path = "imdb.npz", int oov_char= 2, int index_from = 3) { + if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned."); + var dst = Download(); var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); @@ -51,7 +53,7 @@ public DatasetPass load_data(string path = "imdb.npz", x_train_string[i] = lines[i].Substring(2); } - var x_train = np.array(x_train_string); + var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen); File.ReadAllLines(Path.Combine(dst, "imdb_test.txt")); var x_test_string = new string[lines.Length]; @@ -62,7 +64,7 @@ public DatasetPass load_data(string path = "imdb.npz", x_test_string[i] = lines[i].Substring(2); } - var x_test = np.array(x_test_string); + var x_test = keras.preprocessing.sequence.pad_sequences(PraseData(x_test_string), maxlen: maxlen); return new DatasetPass { @@ -93,5 +95,23 @@ string Download() return dst; // return Path.Combine(dst, file_name); } + + protected IEnumerable PraseData(string[] x) + { + var data_list = new List(); + for (int i = 0; i < len(x); i++) + { + var list_string = x[i]; + var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", ""); + string[] number_strings = cleaned_list_string.Split(','); + int[] numbers = new int[number_strings.Length]; + for (int j = 0; j < number_strings.Length; j++) + { + numbers[j] = int.Parse(number_strings[j]); + } + data_list.Add(numbers); + } + return data_list; + } } }