Skip to content

Commit

Permalink
Merge pull request #104 from jajimer/develop
Browse files Browse the repository at this point in the history
 MLFLOW_TRACKING_URI bug solved
  • Loading branch information
AlejandroCN7 authored Dec 17, 2021
2 parents ad47648 + 88d5354 commit 9ae18e7
Showing 1 changed file with 51 additions and 11 deletions.
62 changes: 51 additions & 11 deletions DRL_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,25 @@
'-sto',
action='store_true',
dest='remote_store',
help='Determine if sinergym output will be sent to a common resource')
help='Determine if sinergym output will be sent to a Google Cloud Storage Bucket.')
parser.add_argument(
'--mlflow_store',
'-mlflow',
action='store_true',
dest='mlflow_store',
help='Determine if sinergym output will be sent to a mlflow artifact storage')
parser.add_argument(
'--group_name',
'-group',
type=str,
dest='group_name',
help='This field indicate instance group name')
parser.add_argument(
'--auto_delete',
'-del',
action='store_true',
dest='auto_delete',
help='If is a GCE instance and this flag is active, that instance will be removed from GCP.')

parser.add_argument('--learning_rate', '-lr', type=float, default=.0007)
parser.add_argument('--gamma', '-g', type=float, default=.99)
Expand All @@ -146,6 +158,15 @@
if args.seed:
name += '-seed_' + str(args.seed)
name += '(' + experiment_date + ')'
# Check if MLFLOW_TRACKING_URI is defined
if os.environ.get('MLFLOW_TRACKING_URI') is not None:
# Check ping to server
mlflow_ip = os.environ.get(
'MLFLOW_TRACKING_URI').split('/')[-1].split(':')[0]
# If server is not valid, setting default local path to mlflow
response = os.system("ping -c 1 " + mlflow_ip)
if response != 0:
mlflow.set_tracking_uri('file://' + os.getcwd() + '/mlruns')
# MLflow track
with mlflow.start_run(run_name=name):
# Log experiment params
Expand All @@ -164,7 +185,7 @@
mlflow.log_param('evaluation-length', args.eval_length)
mlflow.log_param('log-interval', args.log_interval)
mlflow.log_param('seed', args.seed)
mlflow.log_param('remote-store', bool(args.seed))
mlflow.log_param('remote-store', bool(args.remote_store))

mlflow.log_param('learning_rate', args.learning_rate)
mlflow.log_param('n_steps', args.n_steps)
Expand Down Expand Up @@ -340,7 +361,24 @@
log_interval=args.log_interval)
model.save(env.simulator._env_working_dir_parent + '/' + name)

# Store all results if remote_store flag is True
# If mlflow artifacts store is active
if args.mlflow_store:
# Code for send output and tensorboard to mlflow artifacts.
mlflow.log_artifacts(
local_dir=env.simulator._env_working_dir_parent,
artifact_path=name + '/')
if args.evaluation:
mlflow.log_artifacts(
local_dir='best_model/' + name + '/',
artifact_path='best_model/' + name + '/')
# If tensorboard is active (in local) we should send to mlflow
if args.tensorboard and 'gs://experiments-storage' not in args.tensorboard:
mlflow.log_artifacts(
local_dir=args.tensorboard + '/' + name + '/',
artifact_path=os.path.abspath(args.tensorboard).split('/')[-1] + '/' + name + '/')

# Store all results if remote_store flag is True (Google Cloud Bucket for
# experiments)
if args.remote_store:
# Initiate Google Cloud client
client = gcloud.init_storage_client()
Expand All @@ -350,18 +388,19 @@
src_path=env.simulator._env_working_dir_parent,
dest_bucket_name='experiments-storage',
dest_path=name)
if args.tensorboard:
gcloud.upload_to_bucket(
client,
src_path=args.tensorboard + '/' + name + '/',
dest_bucket_name='experiments-storage',
dest_path=os.path.abspath(args.tensorboard).split('/')[-1] + '/' + name + '/')
if args.evaluation:
gcloud.upload_to_bucket(
client,
src_path='best_model/' + name + '/',
dest_bucket_name='experiments-storage',
dest_path='best_model/' + name + '/')
# If tensorboard is active (in local) we should send to bucket
if args.tensorboard and 'gs://experiments-storage' not in args.tensorboard:
gcloud.upload_to_bucket(
client,
src_path=args.tensorboard + '/' + name + '/',
dest_bucket_name='experiments-storage',
dest_path=os.path.abspath(args.tensorboard).split('/')[-1] + '/' + name + '/')
# gcloud.upload_to_bucket(
# client,
# src_path='mlruns/',
Expand All @@ -371,7 +410,8 @@
# End mlflow run
mlflow.end_run()

# If it is a Google Cloud VM, shutdown remote machine when ends
if args.group_name:
# If it is a Google Cloud VM and experiment flag auto_delete has been
# activated, shutdown remote machine when ends
if args.group_name and args.auto_delete:
token = gcloud.get_service_account_token()
gcloud.delete_instance_MIG_from_container(args.group_name, token)

0 comments on commit 9ae18e7

Please sign in to comment.