MLP.py
import torch.nn as nn


class MLP(nn.Module):
    # Initialize with input_size, hidden_size, and output_size.
    def __init__(self, input_size=784, hidden_size=256, output_size=10):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            # Input layer
            nn.Linear(input_size, hidden_size),
            # Activation function
            nn.ReLU(),
            # nn.Sigmoid(),
            # nn.Tanh(),
            # Hidden layer
            nn.Linear(hidden_size, hidden_size),
            # Activation function
            nn.ReLU(),
            # nn.Sigmoid(),
            # nn.Tanh(),
            # Output layer
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images to (N, 784) before the linear stack.
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
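

# Usage sketch (an assumption, not part of the original file): the default
# input_size of 784 suggests flattened 28x28 images (e.g. MNIST), so this
# builds the model and runs a forward pass on a dummy batch of that shape.
if __name__ == "__main__":
    import torch

    model = MLP(input_size=784, hidden_size=256, output_size=10)
    dummy_batch = torch.randn(4, 1, 28, 28)  # 4 images, 1 channel, 28x28 pixels
    logits = model(dummy_batch)              # nn.Flatten collapses to (4, 784)
    print(logits.shape)                      # expected: torch.Size([4, 10])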