Skip to content

Commit

Permalink
Merge pull request #28 from marcpinet/feat-more-layers
Browse files Browse the repository at this point in the history
Feat more layers
  • Loading branch information
marcpinet authored Apr 24, 2024
2 parents d64b130 + 513e98e commit 5b81fea
Show file tree
Hide file tree
Showing 12 changed files with 723 additions and 226 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# Dist generator
dist_gen.bat

# Datasets formats
*.csv
*.npz
Expand Down
52 changes: 26 additions & 26 deletions examples/classification-regression/mnist_loading_saved_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:21.706906Z",
"start_time": "2024-04-21T12:52:18.726598200Z"
"end_time": "2024-04-23T23:32:44.879695500Z",
"start_time": "2024-04-23T23:32:41.806868Z"
}
},
"outputs": [],
Expand All @@ -47,8 +47,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:21.915810200Z",
"start_time": "2024-04-21T12:52:21.706906Z"
"end_time": "2024-04-23T23:32:45.056739600Z",
"start_time": "2024-04-23T23:32:44.879695500Z"
}
},
"outputs": [],
Expand All @@ -68,8 +68,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.072282500Z",
"start_time": "2024-04-21T12:52:21.916810900Z"
"end_time": "2024-04-23T23:32:45.166846Z",
"start_time": "2024-04-23T23:32:45.059739600Z"
}
},
"outputs": [],
Expand All @@ -92,8 +92,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.233389700Z",
"start_time": "2024-04-21T12:52:22.073284800Z"
"end_time": "2024-04-23T23:32:45.285935300Z",
"start_time": "2024-04-23T23:32:45.167845600Z"
}
},
"outputs": [],
Expand All @@ -113,8 +113,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.258467800Z",
"start_time": "2024-04-21T12:52:22.234388100Z"
"end_time": "2024-04-23T23:32:45.329886Z",
"start_time": "2024-04-23T23:32:45.288843800Z"
}
},
"outputs": [],
Expand All @@ -134,16 +134,16 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.323518700Z",
"start_time": "2024-04-21T12:52:22.257467100Z"
"end_time": "2024-04-23T23:32:45.374527900Z",
"start_time": "2024-04-23T23:32:45.314964200Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation Accuracy: 0.899\n"
"Validation Accuracy: 0.9738333333333333\n"
]
}
],
Expand All @@ -165,27 +165,27 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.393768500Z",
"start_time": "2024-04-21T12:52:22.318518600Z"
"end_time": "2024-04-23T23:32:45.444303500Z",
"start_time": "2024-04-23T23:32:45.375529400Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 0.8863\n",
"Test Accuracy: 0.9549\n",
"Confusion Matrix:\n",
"[[ 937 0 0 1 11 7 2 18 1 3]\n",
" [ 0 1097 3 4 0 3 2 4 19 3]\n",
" [ 13 9 858 36 26 1 23 38 16 12]\n",
" [ 8 6 18 899 2 33 2 16 12 14]\n",
" [ 1 0 1 0 944 0 7 2 1 26]\n",
" [ 19 0 0 82 30 701 12 5 23 20]\n",
" [ 18 2 0 0 70 15 849 1 2 1]\n",
" [ 0 9 10 5 15 0 0 945 4 40]\n",
" [ 6 22 3 3 37 26 9 2 803 63]\n",
" [ 3 2 1 11 137 2 0 15 8 830]]\n"
"[[ 958 0 3 0 0 3 7 2 4 3]\n",
" [ 0 1117 1 6 0 1 1 2 6 1]\n",
" [ 5 1 983 11 3 0 4 16 9 0]\n",
" [ 2 0 10 959 0 13 1 7 8 10]\n",
" [ 2 1 6 0 909 0 6 0 0 58]\n",
" [ 9 1 0 20 0 838 8 2 3 11]\n",
" [ 10 4 4 1 5 6 917 0 10 1]\n",
" [ 1 8 10 6 0 0 0 982 0 21]\n",
" [ 5 3 9 7 4 6 5 7 917 11]\n",
" [ 3 5 3 5 10 4 2 7 1 969]]\n"
]
}
],
Expand Down
109 changes: 57 additions & 52 deletions examples/classification-regression/simple_cancer_binary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.026361300Z",
"start_time": "2024-04-21T13:22:52.339942200Z"
"end_time": "2024-04-24T00:47:40.722532900Z",
"start_time": "2024-04-24T00:47:40.066546300Z"
}
},
"outputs": [],
Expand All @@ -31,7 +31,7 @@
"\n",
"from neuralnetlib.preprocessing import StandardScaler\n",
"from neuralnetlib.activations import Sigmoid, ReLU\n",
"from neuralnetlib.layers import Input, Activation, Dense\n",
"from neuralnetlib.layers import Input, Activation, Dense, BatchNormalization\n",
"from neuralnetlib.losses import BinaryCrossentropy\n",
"from neuralnetlib.model import Model\n",
"from neuralnetlib.optimizers import Adam\n",
Expand All @@ -51,8 +51,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.040903100Z",
"start_time": "2024-04-21T13:22:53.026361300Z"
"end_time": "2024-04-24T00:47:40.766058Z",
"start_time": "2024-04-24T00:47:40.722532900Z"
}
},
"outputs": [],
Expand All @@ -73,8 +73,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.054442700Z",
"start_time": "2024-04-21T13:22:53.042408400Z"
"end_time": "2024-04-24T00:47:40.767058100Z",
"start_time": "2024-04-24T00:47:40.754053500Z"
}
},
"outputs": [],
Expand All @@ -99,8 +99,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.059957800Z",
"start_time": "2024-04-21T13:22:53.048922300Z"
"end_time": "2024-04-24T00:47:40.776573400Z",
"start_time": "2024-04-24T00:47:40.761057900Z"
}
},
"outputs": [],
Expand All @@ -117,6 +117,7 @@
"\n",
"for _ in range(num_hidden_layers - 1):\n",
" model.add(Dense(hidden_neurons, weights_init='he', random_state=42))\n",
" model.add(BatchNormalization())\n",
" model.add(Activation(ReLU()))\n",
"\n",
"model.add(Dense(output_neurons, random_state=42))\n",
Expand All @@ -135,8 +136,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.085516700Z",
"start_time": "2024-04-21T13:22:53.058950900Z"
"end_time": "2024-04-24T00:47:40.787571200Z",
"start_time": "2024-04-24T00:47:40.771564800Z"
}
},
"outputs": [
Expand All @@ -150,15 +151,19 @@
"Layer 2: Dense(units=100)\n",
"Layer 3: Activation(ReLU)\n",
"Layer 4: Dense(units=100)\n",
"Layer 5: Activation(ReLU)\n",
"Layer 6: Dense(units=100)\n",
"Layer 7: Activation(ReLU)\n",
"Layer 8: Dense(units=100)\n",
"Layer 5: BatchNormalization(momentum=0.99, epsilon=1e-08)\n",
"Layer 6: Activation(ReLU)\n",
"Layer 7: Dense(units=100)\n",
"Layer 8: BatchNormalization(momentum=0.99, epsilon=1e-08)\n",
"Layer 9: Activation(ReLU)\n",
"Layer 10: Dense(units=100)\n",
"Layer 11: Activation(ReLU)\n",
"Layer 12: Dense(units=1)\n",
"Layer 13: Activation(Sigmoid)\n",
"Layer 11: BatchNormalization(momentum=0.99, epsilon=1e-08)\n",
"Layer 12: Activation(ReLU)\n",
"Layer 13: Dense(units=100)\n",
"Layer 14: BatchNormalization(momentum=0.99, epsilon=1e-08)\n",
"Layer 15: Activation(ReLU)\n",
"Layer 16: Dense(units=1)\n",
"Layer 17: Activation(Sigmoid)\n",
"-------------------------------------------------\n",
"Loss function: BinaryCrossentropy\n",
"Optimizer: Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n",
Expand All @@ -184,40 +189,40 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.842873Z",
"start_time": "2024-04-21T13:22:53.081003300Z"
"end_time": "2024-04-24T00:47:41.813204800Z",
"start_time": "2024-04-24T00:47:40.788571700Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[==============================] 100% Epoch 1/20 - loss: 0.6860 - accuracy_score: 0.6308 - 0.04s\n",
"[==============================] 100% Epoch 2/20 - loss: 0.6677 - accuracy_score: 0.7055 - 0.03s\n",
"[==============================] 100% Epoch 3/20 - loss: 0.6323 - accuracy_score: 0.8066 - 0.04s\n",
"[==============================] 100% Epoch 4/20 - loss: 0.5702 - accuracy_score: 0.8901 - 0.05s\n",
"[==============================] 100% Epoch 5/20 - loss: 0.4731 - accuracy_score: 0.9143 - 0.05s\n",
"[==============================] 100% Epoch 6/20 - loss: 0.3540 - accuracy_score: 0.9297 - 0.04s\n",
"[==============================] 100% Epoch 7/20 - loss: 0.2499 - accuracy_score: 0.9429 - 0.04s\n",
"[==============================] 100% Epoch 8/20 - loss: 0.1816 - accuracy_score: 0.9473 - 0.04s\n",
"[==============================] 100% Epoch 9/20 - loss: 0.1418 - accuracy_score: 0.9648 - 0.05s\n",
"[==============================] 100% Epoch 10/20 - loss: 0.1182 - accuracy_score: 0.9714 - 0.04s\n",
"[==============================] 100% Epoch 11/20 - loss: 0.1034 - accuracy_score: 0.9758 - 0.03s\n",
"[==============================] 100% Epoch 12/20 - loss: 0.0927 - accuracy_score: 0.9758 - 0.03s\n",
"[==============================] 100% Epoch 13/20 - loss: 0.0844 - accuracy_score: 0.9802 - 0.03s\n",
"[==============================] 100% Epoch 14/20 - loss: 0.0777 - accuracy_score: 0.9802 - 0.03s\n",
"[==============================] 100% Epoch 15/20 - loss: 0.0722 - accuracy_score: 0.9824 - 0.03s\n",
"[==============================] 100% Epoch 16/20 - loss: 0.0675 - accuracy_score: 0.9846 - 0.03s\n",
"[==============================] 100% Epoch 17/20 - loss: 0.0635 - accuracy_score: 0.9890 - 0.03s\n",
"[==============================] 100% Epoch 18/20 - loss: 0.0600 - accuracy_score: 0.9890 - 0.03s\n",
"[==============================] 100% Epoch 19/20 - loss: 0.0569 - accuracy_score: 0.9890 - 0.04s\n",
"[==============================] 100% Epoch 20/20 - loss: 0.0542 - accuracy_score: 0.9912 - 0.03s\n"
"[==============================] 100% Epoch 1/20 - loss: 0.6905 - accuracy_score: 0.5363 - 0.06s\n",
"[==============================] 100% Epoch 2/20 - loss: 0.6785 - accuracy_score: 0.7209 - 0.05s\n",
"[==============================] 100% Epoch 3/20 - loss: 0.6621 - accuracy_score: 0.8462 - 0.05s\n",
"[==============================] 100% Epoch 4/20 - loss: 0.6433 - accuracy_score: 0.8857 - 0.05s\n",
"[==============================] 100% Epoch 5/20 - loss: 0.6219 - accuracy_score: 0.9011 - 0.05s\n",
"[==============================] 100% Epoch 6/20 - loss: 0.5981 - accuracy_score: 0.9099 - 0.05s\n",
"[==============================] 100% Epoch 7/20 - loss: 0.5717 - accuracy_score: 0.9143 - 0.05s\n",
"[==============================] 100% Epoch 8/20 - loss: 0.5433 - accuracy_score: 0.9275 - 0.04s\n",
"[==============================] 100% Epoch 9/20 - loss: 0.5139 - accuracy_score: 0.9275 - 0.04s\n",
"[==============================] 100% Epoch 10/20 - loss: 0.4846 - accuracy_score: 0.9253 - 0.05s\n",
"[==============================] 100% Epoch 11/20 - loss: 0.4565 - accuracy_score: 0.9209 - 0.04s\n",
"[==============================] 100% Epoch 12/20 - loss: 0.4308 - accuracy_score: 0.9231 - 0.05s\n",
"[==============================] 100% Epoch 13/20 - loss: 0.4077 - accuracy_score: 0.9275 - 0.05s\n",
"[==============================] 100% Epoch 14/20 - loss: 0.3877 - accuracy_score: 0.9363 - 0.05s\n",
"[==============================] 100% Epoch 15/20 - loss: 0.3708 - accuracy_score: 0.9385 - 0.05s\n",
"[==============================] 100% Epoch 16/20 - loss: 0.3571 - accuracy_score: 0.9385 - 0.05s\n",
"[==============================] 100% Epoch 17/20 - loss: 0.3464 - accuracy_score: 0.9407 - 0.06s\n",
"[==============================] 100% Epoch 18/20 - loss: 0.3382 - accuracy_score: 0.9385 - 0.06s\n",
"[==============================] 100% Epoch 19/20 - loss: 0.3317 - accuracy_score: 0.9363 - 0.05s\n",
"[==============================] 100% Epoch 20/20 - loss: 0.3268 - accuracy_score: 0.9385 - 0.04s\n"
]
}
],
"source": [
"model.train(x_train, y_train, epochs=20, batch_size=48, metrics=[accuracy_score], random_state=42)"
"model.fit(x_train, y_train, epochs=20, batch_size=48, metrics=[accuracy_score], random_state=42)"
]
},
{
Expand All @@ -232,16 +237,16 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.857412200Z",
"start_time": "2024-04-21T13:22:53.843829400Z"
"end_time": "2024-04-24T00:47:41.814205700Z",
"start_time": "2024-04-24T00:47:41.794206700Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss: 0.06351246680217817\n"
"Test loss: 0.300304683196257\n"
]
}
],
Expand All @@ -262,8 +267,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.863439400Z",
"start_time": "2024-04-21T13:22:53.852402800Z"
"end_time": "2024-04-24T00:47:41.815206100Z",
"start_time": "2024-04-24T00:47:41.805207700Z"
}
},
"outputs": [],
Expand All @@ -283,19 +288,19 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T13:22:53.873465Z",
"start_time": "2024-04-21T13:22:53.859930800Z"
"end_time": "2024-04-24T00:47:41.826208500Z",
"start_time": "2024-04-24T00:47:41.813204800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9736842105263158\n",
"Precision: 0.9741062479117941\n",
"Recall: 0.9692460317460317\n",
"F1 Score: 0.9716700622635057\n"
"Accuracy: 0.9473684210526315\n",
"Precision: 0.9434523809523809\n",
"Recall: 0.9434523809523809\n",
"F1 Score: 0.9434523809523809\n"
]
}
],
Expand Down
Loading

0 comments on commit 5b81fea

Please sign in to comment.