Skip to content

Commit

Permalink
changed static path logic in model files
Browse files Browse the repository at this point in the history
  • Loading branch information
ramankhurana committed Jan 11, 2024
1 parent 2bc1b9f commit c4e5381
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
9 changes: 6 additions & 3 deletions models/Autoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self, configs):
self.label_len = configs.label_len
self.pred_len = configs.pred_len
self.output_attention = configs.output_attention
self.data = configs.data

self.use_static = True

self.static1 = configs.static=="static1"
Expand Down Expand Up @@ -66,9 +68,10 @@ def __init__(self, configs):
# static_raw = torch.tensor([1, 1, 2, 1, 2, 2, 1]) ## synthetic data for ETTh1

### this is for Divvy bikes
#self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes

self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes
if self.data == "Divvy":
self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes
if self.data == "M5":
self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes

#static_raw = static_raw.repeat((32,72,1)) ## for input it should 96, for output it should be 144
self.static_raw = self.static_raw.repeat((32,self.repeat_freq,1)) ## for input it should 96, for output it should be 144
Expand Down
9 changes: 6 additions & 3 deletions models/DLinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, configs, individual=False):
self.static4 = configs.static=="static4"
self.static6 = configs.static=="static6"
self.static7 = configs.static=="static7"

self.data = configs.data

print ("self.static1, self.static2, self.static4, self.static6, self.static7", self.static1, self.static2, self.static4, self.static6, self.static7)

## Raman code starts
Expand All @@ -43,8 +44,10 @@ def __init__(self, configs, individual=False):
if (self.use_static):

# static_raw = torch.tensor([1, 1, 2, 1, 2, 2, 1]) ## synthetic data for ETTh1
#self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes
self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes
if self.data == "Divvy":
self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes
if self.data == "M5":
self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes

#static_raw = static_raw.repeat((32,72,1)) ## for input it should 96, for output it should be 144
#static_raw = static_raw.repeat((32,144,1)) # for Auto and FED former ## for input it should 96, for output it should be 144
Expand Down
8 changes: 5 additions & 3 deletions models/FEDformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, configs, version='fourier', mode_select='random', modes=32):
self.static6 = configs.static=="static6"
self.static7 = configs.static=="static7"
self.repeat_freq = self.pred_len + self.label_len

self.data = configs.data
# Decomp
self.decomp = series_decomp(configs.moving_avg)

Expand All @@ -54,8 +54,10 @@ def __init__(self, configs, version='fourier', mode_select='random', modes=32):
print ("pred_len, label_len", self.pred_len, self.label_len, self.pred_len + self.label_len)

# static_raw = torch.tensor([1, 1, 2, 1, 2, 2, 1]) ## synthetic data for ETTh1
#self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes
self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes
if self.data == "Divvy":
self.static_raw = torch.tensor(np.load('auxutils/divvy_static.npy').tolist() ) ## static real data for Divvy Bikes
if self.data == "M5":
self.static_raw = torch.tensor(np.load('auxutils/M5_static.npy')[1].tolist() ) ## static real data for Divvy Bikes

#static_raw = static_raw.repeat((32,72,1)) ## for input it should 96, for output it should be 144

Expand Down

0 comments on commit c4e5381

Please sign in to comment.