Skip to content

Commit

Permalink
Update train/predict methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
itsluketwist committed Aug 11, 2024
1 parent bad5440 commit 3c6ec35
Show file tree
Hide file tree
Showing 9 changed files with 740 additions and 836 deletions.
Binary file added output/pretrained/te.pth
Binary file not shown.
18 changes: 10 additions & 8 deletions predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
"- A csv file of previous game data is needed before predictions can begin.\n",
"- To predict the result of the upcoming game between HomeTeam and AwayTeam, the CSV needs to \n",
"contain the results and statistics of each teams previous games.\n",
"- The first record in the csv should contain the oldest game data, and the last record should game data closest to the game to predict.\n",
"- The first record in the csv should contain the oldest game data, and the last record should contain game data closest to the game to predict.\n",
"- The csv needs the columns listed below, where the `home` prefix refers to information about previous games of HomeTeam, and similarly `away` prefix for AwayTeam.\n",
"- By default, the code expects the csv to have data from 8 previous games.\n",
"- Example csv files are available in `data/predict_csv/`."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -155,14 +156,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 10, 116])\n"
"torch.Size([1, 8, 116])\n"
]
}
],
Expand All @@ -179,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -195,23 +196,24 @@
"\n",
"import torch\n",
"\n",
"model_path = \"output/trained/lstm.pth\"\n",
"model_type = \"lstm\" # choose between lstm, te\n",
"model_path = f\"output/pretrained/{model_type}.pth\"\n",
"model = torch.load(f=model_path)\n",
"\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The HOME team will win.\n",
"Home win prediction: 0.7150542736053467\n"
"Home win prediction: 0.9862344861030579\n"
]
}
],
Expand Down
2 changes: 2 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ torch==2.3.1
# -r requirements.txt
# pytorch-tcn
# torchvision
torchinfo==1.8.0
# via -r requirements.txt
torchvision==0.18.1
# via -r requirements.txt
tornado==6.4.1
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ pre-commit
pytorch-tcn
scikit-learn
torch
torchinfo
torchvision
tqdm
15 changes: 12 additions & 3 deletions src/loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch import Generator, Tensor
from torch.nn.functional import normalize
from torch.utils.data import DataLoader, Dataset, Subset, random_split


Expand All @@ -27,7 +27,7 @@ def __init__(
self,
data_file: str,
verbose: bool = False,
sequence_len: int = 10,
sequence_len: int = 8,
normalize: bool = False,
as_sequence: bool = True,
**kwargs,
Expand All @@ -47,7 +47,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, str | int]:
sequence = Tensor(item["data"])[-self.sequence_len :]

if self.normalize:
sequence = normalize(sequence)
sequence = F.normalize(sequence)

if not self.as_sequence:
sequence = sequence.reshape([self.sequence_len * self.vector_len])
Expand All @@ -72,6 +72,7 @@ def get_train_dataloader(
dataset_class: Dataset = GameSequenceDataset,
sequence_len: int = 10,
as_sequence: bool = True,
normalize: bool = True,
) -> tuple[DataLoader, DataLoader]:
"""
Create dataloaders for training a model with NBA game data.
Expand All @@ -84,6 +85,7 @@ def get_train_dataloader(
dataset_class: Dataset = GameSequenceDataset
sequence_len: int = 10
as_sequence: bool = True
normalize: bool = True
Returns
-------
Expand All @@ -93,6 +95,7 @@ def get_train_dataloader(
data_file=parquet_file,
sequence_len=sequence_len,
as_sequence=as_sequence,
normalize=normalize,
)

num_train = int(len(raw_data) * train_split)
Expand Down Expand Up @@ -122,6 +125,7 @@ def get_eval_dataloader(
dataset_class: Dataset = GameSequenceDataset,
sequence_len: int = 10,
as_sequence: bool = True,
normalize: bool = True,
) -> DataLoader:
"""
Create a dataloader for evaluating a model with NBA game data.
Expand All @@ -132,6 +136,7 @@ def get_eval_dataloader(
dataset_class: Dataset = GameSequenceDataset
sequence_len: int = 10
as_sequence: bool = True
normalize: bool = True
Returns
-------
Expand All @@ -141,6 +146,7 @@ def get_eval_dataloader(
data_file=parquet_file,
sequence_len=sequence_len,
as_sequence=as_sequence,
normalize=normalize,
)
return DataLoader(
dataset=raw_data,
Expand All @@ -154,6 +160,7 @@ def get_sample_dataloader(
dataset_class: Dataset = GameSequenceDataset,
sequence_len: int = 10,
as_sequence: bool = True,
normalize: bool = True,
) -> DataLoader:
"""
Create a dataloader for providing sample NBA game data.
Expand All @@ -165,6 +172,7 @@ def get_sample_dataloader(
dataset_class: Dataset = GameSequenceDataset
sequence_len: int = 10
as_sequence: bool = True
normalize: bool = True
Returns
-------
Expand All @@ -175,6 +183,7 @@ def get_sample_dataloader(
verbose=True,
sequence_len=sequence_len,
as_sequence=as_sequence,
normalize=normalize,
)
idxs = np.random.choice(range(0, len(raw_data)), size=(count,))
_subset = Subset(raw_data, idxs)
Expand Down
11 changes: 10 additions & 1 deletion src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def plot_history(
history: History,
model_name: str,
save_location: str = "output",
title: str | None = None,
):
"""
Plot the training history curves.
Expand All @@ -20,6 +21,8 @@ def plot_history(
Name of the model being plotted, used in plot title and saved filename.
save_location: str = "output"
Where to save the plot.
title: str | None = None
Whether to override the default title.
"""
# configure matplotlib
matplotlib.use("Agg")
Expand All @@ -28,7 +31,13 @@ def plot_history(
x_ticks = [x + 1 for x in range(len(history.train_accuracy))]

fig = plt.figure(figsize=(12, 6))
plt.suptitle(f"{model_name.capitalize()} Training ({filename_datetime()})", size=16)

if title is not None:
plt.suptitle(title)
else:
plt.suptitle(
f"{model_name.capitalize()} Training ({filename_datetime()})", size=16
)

# plot loss on training and validation sets
_ = fig.add_subplot(1, 2, 1)
Expand Down
21 changes: 20 additions & 1 deletion src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import logging

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

from src import models
from src.utils import device


Expand Down Expand Up @@ -34,6 +36,9 @@ def make_prediction(
model = model.to(device)
data = data.to(device)

if isinstance(model, models.TE):
data = data.reshape([1, len(data[0]) * len(data[0][0])])

# switch off autograd for
with torch.no_grad():
batch_pred = model(data)
Expand All @@ -51,6 +56,8 @@ def make_prediction(
def load_record_from_csv(
file_path: str,
has_header_row: bool = True,
sequence_len: int = 8,
normalize: bool = True,
) -> Tensor:
"""
Utility function to load game result data from a csv into a Tensor, ready for predictions.
Expand All @@ -61,6 +68,10 @@ def load_record_from_csv(
Path to csv file to load.
has_header_row: bool = True
Whether the csv file has a header row.
sequence_len: int = 8
What sequence length to load and use.
normalize: bool = True
Whether to normalize the data after loading.
Returns
-------
Expand All @@ -77,4 +88,12 @@ def load_record_from_csv(
for row in reader:
data.append([float(x) for x in row])

return Tensor([data])
if len(data) == sequence_len:
break

_tensor = Tensor([data])

if normalize:
return F.normalize(_tensor)
else:
return _tensor
Loading

0 comments on commit 3c6ec35

Please sign in to comment.