Given X of shape [<n_samples>, <n_features>, <n_lookback>] and a trained TSForecaster learner based on PatchTST, is there a way to retrieve the embeddings for each of the samples? I found a related discussion for TST that was closed in Dec. 2022 (see #560), but I have trouble finding the right shape when triggering the forward() of the PatchTST model. Since PatchTST operates on patches, its forward implementation works with a different, patch-based shape. Does anyone have suggestions on how to tackle this problem and retrieve the embeddings?

PS: I also modified the method suggested in the referenced discussion to fit PatchTST. I am not quite sure whether it is correct, but it seems to hook into the right place for embedding retrieval.
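A generic way to find out what shape a submodule produces, independent of tsai, is a plain PyTorch forward hook. Below is a minimal sketch with a stand-in patch-based backbone (`ToyPatchBackbone` is invented for illustration; the real PatchTST backbone has a different internal structure), showing how a hook reveals the intermediate output shape:

```python
import torch
import torch.nn as nn

# Stand-in for a patch-based backbone. This is NOT the real PatchTST
# backbone; it only mimics the idea of splitting the sequence into
# patches and projecting each patch to d_model dimensions.
class ToyPatchBackbone(nn.Module):
    def __init__(self, patch_len=16, d_model=64):
        super().__init__()
        self.patch_len = patch_len
        self.proj = nn.Linear(patch_len, d_model)

    def forward(self, x):  # x: [bs, n_vars, seq_len]
        # split the sequence into non-overlapping patches
        patches = x.unfold(-1, self.patch_len, self.patch_len)  # [bs, n_vars, num_patch, patch_len]
        z = self.proj(patches)                                  # [bs, n_vars, num_patch, d_model]
        return z.transpose(2, 3)                                # [bs, n_vars, d_model, num_patch]

captured = {}
def shape_hook(module, inputs, output):
    # record the shape of whatever this module returns
    captured["shape"] = tuple(output.shape)

model = ToyPatchBackbone()
handle = model.register_forward_hook(shape_hook)
x = torch.randn(4, 2, 128)  # [bs, n_vars, seq_len]
_ = model(x)
handle.remove()
print(captured["shape"])    # prints (4, 2, 64, 8)
```

The same `register_forward_hook` call can be attached to any submodule of a tsai model to check which shape its forward pass actually sees and emits.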
Hi @MMayr96,
If by embeddings you mean the output of a given layer, I'd recommend you use fastai's Hook (see this to learn more about hooks). You just need to decide which output you want to pull. If you want to pull the output of the backbone before it's passed to the model's head, you could use something like this:
```python
from tsai.basics import *
from fastai.callback.hook import Hook

fcst_history = 60
fcst_horizon = 3

df = get_forecasting_time_series("Sunspots")
X, y = prepare_forecasting_data(df, fcst_history=fcst_history, fcst_horizon=fcst_horizon, dtype=np.float32)
splits = get_forecasting_splits(df, fcst_history=fcst_history, fcst_horizon=fcst_horizon, test_size=235, show_plot=False)

batch_tfms = TSStandardize()
fcst = TSForecaster(X, y, splits=splits, path='models', batch_tfms=batch_tfms, bs=64,
                    arch="PatchTST", metrics=mae, cbs=ShowGraph())

# register a forward hook on the backbone to capture its output
hook = Hook(fcst.model.model.backbone, lambda m, i, o: o, is_forward=True, detach=True, cpu=True)
xb, _ = fcst.dls.train.one_batch()
_ = fcst.model(xb)
embeddings = hook.stored  # backbone output for this batch
hook.remove()
```
If you want the embeddings for every batch:

```python
hook = Hook(fcst.model.model.backbone, lambda m, i, o: o, is_forward=True, detach=True, cpu=True)
embeddings = []
for xb, _ in fcst.dls.train:
    _ = fcst.model(xb)
    embeddings.append(hook.stored)  # hook.stored is overwritten on each forward pass
hook.remove()
embeddings = torch.cat(embeddings)
```

Is that what you were looking for?
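The hooked backbone output is per-patch and per-variable rather than one vector per sample. If a single embedding vector per sample is wanted, one common option (an assumption, not something tsai prescribes) is to mean-pool over the patch dimension and flatten the remaining feature dimensions:

```python
import torch

# Stand-in for hook.stored: assume a tensor shaped
# [bs, n_vars, d_model, num_patch]; verify this layout against your
# own hook output before relying on it.
emb = torch.randn(64, 2, 128, 10)

pooled = emb.mean(dim=-1)           # average over patches -> [bs, n_vars, d_model]
flat = pooled.flatten(start_dim=1)  # one vector per sample -> [bs, n_vars * d_model]
print(flat.shape)                   # prints torch.Size([64, 256])
```

Max-pooling over patches, or keeping the per-variable vectors separate, are equally valid choices depending on what the embeddings are used for downstream.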