Skip to content

Commit

Permalink
Fixed TFT transform
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Jul 7, 2021
1 parent 1a94965 commit cc7dea9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
9 changes: 7 additions & 2 deletions pts/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from torch.utils import data
from torch.utils.data import DataLoader

from gluonts.env import env
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.support.util import maybe_len

from pts import Trainer
from pts.model import get_module_forward_input_names
Expand Down Expand Up @@ -101,7 +103,9 @@ def train_model(
trained_net = self.create_training_network(self.trainer.device)

input_names = get_module_forward_input_names(trained_net)
training_instance_splitter = self.create_instance_splitter("training")

with env._let(max_idle_transforms=maybe_len(training_data) or 0):
training_instance_splitter = self.create_instance_splitter("training")
training_iter_dataset = TransformedIterableDataset(
dataset=training_data,
transform=transformation
Expand All @@ -124,7 +128,8 @@ def train_model(

validation_data_loader = None
if validation_data is not None:
validation_instance_splitter = self.create_instance_splitter("validation")
with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
validation_instance_splitter = self.create_instance_splitter("validation")
validation_iter_dataset = TransformedIterableDataset(
dataset=validation_data,
transform=transformation
Expand Down
2 changes: 2 additions & 0 deletions pts/model/tft/tft_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch

from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast_generator import QuantileForecastGenerator
Expand Down Expand Up @@ -30,6 +31,7 @@
from pts import Trainer
from pts.model import PyTorchEstimator
from pts.model.utils import get_module_forward_input_names

from .tft_network import (
TemporalFusionTransformerPredictionNetwork,
TemporalFusionTransformerTrainingNetwork,
Expand Down
36 changes: 19 additions & 17 deletions pts/model/tft/tft_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
shift_timestamp,
target_transformation_length,
)
from gluonts.transform.sampler import InstanceSampler


class BroadcastTo(MapTransformation):
Expand Down Expand Up @@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter):
@validated()
def __init__(
self,
instance_sampler,
instance_sampler: InstanceSampler,
past_length: int,
future_length: int,
target_field: str = FieldName.TARGET,
Expand All @@ -64,29 +65,30 @@ def __init__(
observed_value_field: str = FieldName.OBSERVED_VALUES,
lead_time: int = 0,
output_NTC: bool = True,
time_series_fields: Optional[List[str]] = None,
past_time_series_fields: Optional[List[str]] = None,
time_series_fields: List[str] = [],
past_time_series_fields: List[str] = [],
dummy_value: float = 0.0,
) -> None:

super().__init__(
target_field=target_field,
is_pad_field=is_pad_field,
start_field=start_field,
forecast_start_field=forecast_start_field,
instance_sampler=instance_sampler,
past_length=past_length,
future_length=future_length,
lead_time=lead_time,
output_NTC=output_NTC,
time_series_fields=time_series_fields,
dummy_value=dummy_value,
)

assert past_length > 0, "The value of `past_length` should be > 0"
assert future_length > 0, "The value of `future_length` should be > 0"

self.instance_sampler = instance_sampler
self.past_length = past_length
self.future_length = future_length
self.lead_time = lead_time
self.output_NTC = output_NTC
self.dummy_value = dummy_value

self.target_field = target_field
self.is_pad_field = is_pad_field
self.start_field = start_field
self.forecast_start_field = forecast_start_field
self.observed_value_field = observed_value_field

self.ts_fields = time_series_fields or []
self.past_ts_fields = past_time_series_fields or []
self.past_ts_fields = past_time_series_fields

def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
pl = self.future_length
Expand Down

0 comments on commit cc7dea9

Please sign in to comment.