How to export user-defined metadata of experiment during training #7888
I am training a PyTorch model with Determined AI and create the data splits (train and validation) dynamically when my PyTorchTrial is initialized. The splits are random, so they differ on every run of the experiment. I would like to retrieve them after training for analysis in my developer workspace, without uploading the metadata to cloud storage. Here is a snippet of how the data splits are created:
My question is: how can I retrieve the data splits after the experiment has run? Is there a way to export this metadata so it can be accessed easily outside of the Determined environment?
Replies: 1 comment
You can use a `PyTorchCallback` to save arbitrary data to the checkpoint directory:

```python
import pathlib
from typing import Dict

import torch
from determined.pytorch import PyTorchCallback, PyTorchTrial


class DeterminedMain(PyTorchTrial):
    def __init__(self, context):
        # The random train/validation splits are created here (elided).
        self.train_indices = ...
        self.val_indices = ...

    def build_callbacks(self) -> Dict[str, PyTorchCallback]:
        # Hand the indices to the callback so it can save them with each checkpoint.
        return {"save_indices_callback": SaveIndicesCallback(self.train_indices, self.val_indices)}


class SaveIndicesCallback(PyTorchCallback):
    def __init__(self, train_indices, val_indices) -> None:
        self.train_indices = train_indices
        self.val_indices = val_indices
        super().__init__()

    def on_checkpoint_write_end(self, checkpoint_dir: str) -> None:
        # Runs after a checkpoint has been written to checkpoint_dir;
        # files added here are stored alongside the checkpoint.
        print(f"checkpoint dir: {checkpoint_dir}")
        torch.save(self.train_indices, pathlib.Path(checkpoint_dir) / "train_indices.pt")
        torch.save(self.val_indices, pathlib.Path(checkpoint_dir) / "val_indices.pt")
```

Then later, download a specific checkpoint and load the indices:

```python
import pathlib

import torch
from determined.experimental import client

trial_num = 1234  # ID of the trial whose splits you want
trial = client.get_trial(trial_num)
checkpoint = trial.top_checkpoint()
# download() returns the local path of the downloaded checkpoint directory.
checkpoint_dir = pathlib.Path(checkpoint.download())
train_indices = torch.load(checkpoint_dir / "train_indices.pt")
val_indices = torch.load(checkpoint_dir / "val_indices.pt")
```
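Once the index files are on your machine, rebuilding the splits comes down to indexing the original dataset with them. Below is a minimal sketch of that round trip. It assumes the indices are plain lists of integers and writes them as JSON rather than with `torch.save`, so it runs without Determined or PyTorch installed. The helper names `save_split_indices`, `load_split_indices`, and `select`, and the file name `split_indices.json`, are hypothetical and not part of any library.

```python
import json
import pathlib
import tempfile
from typing import Dict, List, Sequence


def save_split_indices(directory: str, train: Sequence[int], val: Sequence[int]) -> pathlib.Path:
    """Write both index lists to one JSON file inside `directory` (hypothetical helper)."""
    path = pathlib.Path(directory) / "split_indices.json"
    path.write_text(json.dumps({"train": list(train), "val": list(val)}))
    return path


def load_split_indices(directory: str) -> Dict[str, List[int]]:
    """Read the index lists back from `directory` (hypothetical helper)."""
    path = pathlib.Path(directory) / "split_indices.json"
    return json.loads(path.read_text())


def select(dataset: Sequence, indices: Sequence[int]) -> list:
    """Rebuild one split by picking the saved positions out of the full dataset."""
    return [dataset[i] for i in indices]


if __name__ == "__main__":
    dataset = [f"sample_{i}" for i in range(6)]
    # A temporary directory stands in for the checkpoint directory here.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        save_split_indices(checkpoint_dir, train=[4, 0, 3, 5], val=[1, 2])
        splits = load_split_indices(checkpoint_dir)
    print(select(dataset, splits["val"]))
```

Inside a callback, the same save helper could be called from `on_checkpoint_write_end` with the `checkpoint_dir` it receives. JSON keeps the file readable from any tool, at the cost of only handling plain Python lists.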