-
Notifications
You must be signed in to change notification settings - Fork 1
/
solubility_model_building.py
49 lines (38 loc) · 1.54 KB
/
solubility_model_building.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pandas as pd
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import numpy as np
import pickle
# Download and Read the Dataset
delaney_with_descriptors_url = 'https://raw.githubusercontent.com/dataprofessor/data/master/delaney_solubility_with_descriptors.csv'
dataset = pd.read_csv(delaney_with_descriptors_url)
# Split Traing Data and Labels
X = dataset.drop(['logS'], axis=1)
Y = dataset.iloc[:,-1]
# Define the Linear Regression Model and Train it using Training Data X
model = linear_model.LinearRegression()
model.fit(X, Y)
# Predictions
Y_pred = model.predict(X)
# Calculate the Performance of the model
print('Coefficients:', model.coef_)
print('Intercept:', model.intercept_)
print('Mean squared error (MSE): %.2f'
% mean_squared_error(Y, Y_pred))
print('Coefficient of determination (R^2): %.2f'
% r2_score(Y, Y_pred))
# Print Model's Equation, using the Coefficients that were found
print('LogS = %.2f %.2f LogP %.4f MW + %.4f RB %.2f AP' % (model.intercept_, model.coef_[0], model.coef_[1], model.coef_[2], model.coef_[3] ) )
# Data Visualization
plt.figure(figsize=(5,5))
plt.scatter(x=Y, y=Y_pred, c="#7CAE00", alpha=0.3)
# Add trendline
# https://stackoverflow.com/questions/26447191/how-to-add-trendline-in-python-matplotlib-dot-scatter-graphs
z = np.polyfit(Y, Y_pred, 1)
p = np.poly1d(z)
plt.plot(Y,p(Y),"#F8766D")
plt.ylabel('Predicted LogS')
plt.xlabel('Experimental LogS')
# Save Model as Pickle Object
pickle.dump(model, open('solubility_model.pkl', 'wb'))