-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.sh
64 lines (50 loc) · 1.36 KB
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env bash
#######################################
# Print the command-line help for this script.
# Arguments: none
# Outputs:   writes the help text to stdout
#######################################
usage() {
  # Quoted delimiter: the text is emitted literally, with no expansion.
  # Only the job types the script actually dispatches on are listed.
  cat <<'EOF'
usage: train.sh [local | remote]
Use 'local' to train locally with a local data file, and 'remote' to
run on ML Engine. For ML Engine jobs the train and valid directories must reside on GCS.
Examples:
# train locally
./train.sh local
# train on ML Engine with hparms.py
./train.sh remote

EOF
}
# Fail fast: abort on command errors, unset variables and failed pipeline stages.
set -euo pipefail

# Print the current date/time at the start of the run.
date
# Timestamp suffix that is embedded in the job name.
TIME=$(date +"%Y%m%d_%H%M%S")

# GCS locations. YOUR_BUCKET is a placeholder and must be replaced before use.
BUCKET_NAME=YOUR_BUCKET
BUCKET="gs://${BUCKET_NAME}"
DATAPATH="gs://${BUCKET_NAME}/data"
# Weights file on GCS; the path follows the jobs/<job name>/ layout used for --job-dir.
WEIGHTS="gs://${BUCKET_NAME}/jobs/mst_training_remote_20190524_103506/weights/decoder.h5"
# Weights file on the local filesystem.
LOCAL_WEIGHTS=./trainer/data/weights/weights.h5

# The job type argument is mandatory. Use an arithmetic test: inside [[ ]]
# the `<` operator compares strings lexicographically, not numbers.
if (( $# < 1 )); then
  usage >&2
  exit 1
fi

# set job vars
JOB_TYPE="$1"
# Optional second argument; defaulted to empty so `set -u` does not abort when it is omitted.
EVAL="${2:-}"
# Exported so that child processes (such as gcloud) inherit the job name.
export JOB_NAME="mst_training_${JOB_TYPE}_${TIME}"
REGION=europe-west1
# Dispatch on the requested job type. In both gcloud invocations the flags
# after the bare `--` are forwarded to the trainer.train module.
if [[ "${JOB_TYPE}" == "local" ]]; then
  # The last argument must NOT end with a backslash: a trailing `\` would
  # splice the following `elif` line into the gcloud command.
  gcloud ml-engine local train \
    --module-name trainer.train \
    --package-path ./trainer \
    -- \
    --datapath trainer/data \
    --job-dir "trainer/jobs/${JOB_NAME}/" \
    --weights "${LOCAL_WEIGHTS}"
elif [[ "${JOB_TYPE}" == "remote" ]]; then
  gcloud ml-engine jobs submit training "${JOB_NAME}" \
    --region "${REGION}" \
    --job-dir "${BUCKET}/jobs/${JOB_NAME}/" \
    --module-name trainer.train \
    --package-path ./trainer \
    --config trainer/config/config_train.json \
    -- \
    --datapath "${DATAPATH}"
  # Disabled option: to pass the GCS weights file, end the line above with
  # a backslash and append:
  #   --weights "${WEIGHTS}"
else
  # Unknown job type: show the help text and report failure to the caller.
  usage >&2
  exit 1
fi