-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlinear_regresion_pytorch.py
82 lines (64 loc) · 1.76 KB
/
linear_regresion_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# %%
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
from skorch import NeuralNetRegressor
from sklearn.model_selection import GridSearchCV
# %%
# Download the cars CSV straight from a GitHub gist (needs network access).
# The columns used later in this script are "wt" and "mpg".
cars_file = (
    "https://gist.githubusercontent.com/noamross/e5d3e859aa0c794be10b/raw/"
    "b999fb4425b54c63cab088c0ce2c0d6ce961a563/cars.csv"
)
cars = pd.read_csv(filepath_or_buffer=cars_file)
# Peek at the first rows (only displayed when run as a notebook cell).
cars.head()
# %%
# Visualise weight vs. mpg together with a fitted regression line.
# sns.regplot already draws the individual observations as a scatter, so a
# preceding sns.scatterplot on the same axes would only draw every point a
# second time (in a different colour).
sns.regplot(x="wt", y="mpg", data=cars)
# %%
# Turn the "wt" (feature) and "mpg" (target) columns into float32 column
# vectors of shape (n_rows, 1) and wrap them as tensors.
# Note: torch.from_numpy shares memory with the NumPy arrays (no copy).
X_list = cars["wt"].to_numpy()
y_list = cars["mpg"].to_numpy()
X_np = X_list.astype(np.float32).reshape(-1, 1)
y_np = y_list.astype(np.float32).reshape(-1, 1)
X, y_true = torch.from_numpy(X_np), torch.from_numpy(y_np)
# %%
class LinearRegressionDataset(Dataset):
    """Map-style dataset that pairs each input row with its target.

    Args:
        X: Sized, indexable collection of inputs.
        y: Sized, indexable collection of targets; must have the same
            length as ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """

    def __init__(self, X, y):
        # __len__ is based on X alone, so a length mismatch would either
        # silently drop trailing targets or make __getitem__ raise
        # IndexError part-way through an epoch. Reject it up front instead.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]
# %%
class LinearRegressionTorch(nn.Module):
    """Linear regression model: a single ``nn.Linear`` layer (``x @ W.T + b``).

    Args:
        input_size: Number of input features per sample (default 1).
        output_size: Number of outputs per sample (default 1).
    """

    def __init__(self, input_size=1, output_size=1):
        super().__init__()
        # The attribute name ``linear`` fixes the parameter names in the
        # state_dict ("linear.weight", "linear.bias"), so keep it stable.
        layer = nn.Linear(input_size, output_size)
        self.linear = layer

    def forward(self, x):
        """Apply the affine layer to ``x`` (last dim must equal ``input_size``)."""
        prediction = self.linear(x)
        return prediction
# %%
# One input feature (weight) mapped to one output (mpg).
input_dim, output_dim = 1, 1
model = LinearRegressionTorch(input_size=input_dim, output_size=output_dim)
# Put the module in training mode.
model.train()
# %%
loss_fun = nn.MSELoss()  # mean squared error between predictions and targets
# %%
learning_rate = 0.02  # step size used by the optimizer below
# %%
# Plain stochastic gradient descent over all parameters of `model`.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# %%
# Wrap the module class with skorch so it exposes a scikit-learn style
# estimator interface (fit / predict / set_params).
net = NeuralNetRegressor(
    module=LinearRegressionTorch,
    lr=0.1,
    max_epochs=10,
    iterator_train__shuffle=True,  # reshuffle training data every epoch
)
# %%
# Turn off skorch's internal train/validation split (a falsy `train_split`
# disables it, per the skorch docs) and silence per-epoch logging.
# NOTE(review): this looks like preparation for an external cross-validator
# such as GridSearchCV (imported at the top), which is not run in the code
# visible here — confirm the intended next step.
net.set_params(train_split=False, verbose=0)
# Hyper-parameter grid: every combination of learning rate and epoch count.
params = dict(
    lr=[0.02, 0.05, 0.08, 0.001],
    max_epochs=[200, 600, 80, 90],
)