-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
55 lines (42 loc) · 1.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from request import Request
from model import LSTMModel
import matplotlib.pyplot as plt
import tkinter as tk
class Application(tk.Frame):
    """Small Tk form: type a ticker symbol, train the model on its prices, plot the result.

    The label, entry and button are packed directly into ``master``; the
    frame itself is never packed.
    """

    def __init__(self, master=None):
        """Build the form inside *master* (Tk's default root window when None)."""
        super().__init__(master)
        # tk.Frame.__init__ already stores the resolved parent on self.master
        # (falling back to the default root when master is None). Re-assigning
        # the raw argument here used to clobber it with None and make the
        # geometry() call below raise AttributeError.
        self.master.geometry("300x200")
        self.create_widgets()

    def create_widgets(self):
        """Create the prompt label, the ticker entry (E1) and the Submit button (B1)."""
        self.L1 = tk.Label(self.master, text="Input Ticker Symbol")
        self.L1.pack(side=tk.LEFT)
        self.E1 = tk.Entry(self.master, bd=5)
        self.E1.pack(side=tk.RIGHT)
        # Pass the bound method directly; no lambda wrapper is needed.
        self.B1 = tk.Button(self.master, text="Submit", command=self.runModel)
        self.B1.pack(side=tk.BOTTOM)

    @staticmethod
    def _normalize_ticker(raw):
        """Return *raw* as a ticker symbol: surrounding whitespace removed, upper-cased."""
        return str(raw).strip().upper()

    def say_hi(self):
        """Print the normalized ticker currently typed into the entry (debug helper)."""
        print(self._normalize_ticker(self.E1.get()))

    def runModel(self):
        """Fetch prices for the entered ticker, train the model, and plot predictions vs. price.

        If the entry is blank (or only whitespace), a warning dialog is shown
        and nothing is fetched or trained.
        """
        ticker = self._normalize_ticker(self.E1.get())
        if not ticker:
            # Local import: messagebox is a tkinter submodule not loaded by
            # `import tkinter` alone.
            from tkinter import messagebox
            messagebox.showwarning("Missing ticker", "Please enter a ticker symbol.")
            return

        # Get the price history for the ticker.
        data = Request().StockPrices(ticker)

        # Train the model. The 5 is forwarded unchanged to trainData; its
        # meaning is defined in the model module (TODO confirm and name it).
        result = LSTMModel().trainData(data, 5)
        print(result)

        # Assumes result is a pandas DataFrame indexed by date with
        # 'Predictions' and 'Close' columns (TODO confirm against model.py).
        predictions = result['Predictions'].to_list()
        closes = result['Close'].to_list()
        dates = result.index.to_list()

        # Start a fresh figure so a previous run's lines are never drawn over.
        plt.figure()
        plt.plot(dates, predictions, label="predictions")
        plt.plot(dates, closes, label="price")
        plt.title(ticker)
        plt.legend()
        plt.show()

    def downloadCsv(self):
        """Placeholder for exporting model results to CSV; currently does nothing and returns None."""
        # TODO: not implemented yet.
        return
if __name__ == "__main__":
    # Script entry point: create the Tk root window, attach the ticker form to
    # it, then hand control to Tk's event loop until the window is closed.
    main_window = tk.Tk()
    form = Application(master=main_window)
    form.mainloop()