diff --git a/src/train.py b/src/train.py index 8ec2a65..cc12b62 100644 --- a/src/train.py +++ b/src/train.py @@ -18,7 +18,6 @@ from accelerate import Accelerator, ProfileKwargs from accelerate.utils import ProjectConfiguration -import subprocess import logging import ujson import os @@ -233,7 +232,6 @@ def trace_handler(prof): proj_drop_rate=config['dropout'], fixed_dropout_depth=config['fixed_dropout_depth'], ) - # elif network == 'prototype': # model = OpticalTransformer( # name='Prototype', @@ -480,7 +478,7 @@ def trace_handler(prof): accelerator.save_state(config['checkpointdir'] / 'last_state') accelerator.end_training() - subprocess.call("ray stop --force", shell=True) + return def train_model( @@ -651,6 +649,7 @@ def train_model( ) result = trainer.fit() + return result def eval_model(