-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
86 lines (72 loc) · 2.5 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# imports
import mlflow
import argparse
import dask_mpi
import xgboost as xgb
import dask.dataframe as dd
from distributed import Client
from adlfs import AzureBlobFileSystem
# define functions
def main(args):
    """Train a distributed XGBoost model on the Azure malware dataset.

    Bootstraps a dask-MPI cluster, reads train/test parquet files from the
    public ``azuremlexamples`` blob container, trains an XGBoost binary
    classifier on the numeric columns, writes predictions to
    ./outputs/predictions.csv, and logs params and the booster via MLflow.

    Args:
        args: argparse.Namespace with num_boost_round, learning_rate,
            gamma, max_depth, and cpus_per_node.

    Raises:
        FileNotFoundError: if no 'train' or 'test' parquet file is found
            under the processed data folder.
    """
    # distributed setup: dask_mpi spawns scheduler/workers over MPI,
    # then Client() attaches to that freshly created cluster
    print("initializing...")
    dask_mpi.initialize(nthreads=args.cpus_per_node)
    client = Client()
    print(client)
    # get data from the public storage account (anonymous access)
    print("connecting to data...")
    print(client)
    container_name = "malware"
    storage_options = {"account_name": "azuremlexamples"}
    fs = AzureBlobFileSystem(**storage_options)
    files = fs.ls(f"{container_name}/processed")
    # read into dataframes; initialize to None so a missing file fails fast
    # below instead of raising a confusing UnboundLocalError later
    print("creating dataframes...")
    print(client)
    df_train = None
    df_test = None
    for f in files:
        if "train" in f:
            df_train = dd.read_parquet(f"az://{f}", storage_options=storage_options)
        elif "test" in f:
            df_test = dd.read_parquet(f"az://{f}", storage_options=storage_options)
    if df_train is None or df_test is None:
        raise FileNotFoundError(
            f"expected 'train' and 'test' parquet files under "
            f"{container_name}/processed, found: {files}"
        )
    # data processing: keep numeric columns only; split label from features
    print("processing data...")
    print(client)
    cols = [col for col in df_train.columns if df_train.dtypes[col] != "object"]
    X = df_train[cols].drop("HasDetections", axis=1).values.persist()
    y = df_train["HasDetections"].persist()
    # train xgboost across the dask cluster
    print("training xgboost...")
    print(client)
    params = {
        "objective": "binary:logistic",
        "learning_rate": args.learning_rate,
        "gamma": args.gamma,
        "max_depth": args.max_depth,
    }
    mlflow.log_params(params)  # log hyperparameters to the active run
    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    model = xgb.dask.train(client, params, dtrain, num_boost_round=args.num_boost_round)
    print(model)
    # predict on test data, excluding the label column if it is present
    print("making predictions...")
    print(client)
    X_test = df_test[
        [col for col in cols if "HasDetections" not in col]
    ].values.persist()
    y_pred = xgb.dask.predict(client, model, X_test)
    y_pred.to_dask_dataframe().to_csv("./outputs/predictions.csv")
    # save model (xgb.dask.train returns a dict with "booster" and "history")
    print("saving model...")
    print(client)
    mlflow.xgboost.log_model(model["booster"], "./outputs/model")
if __name__ == "__main__":
    # CLI setup: hyperparameters for XGBoost plus cluster sizing,
    # declared as a (flag, type, default) table to avoid repetition
    parser = argparse.ArgumentParser()
    arg_specs = [
        ("--num_boost_round", int, 10),
        ("--learning_rate", float, 0.1),
        ("--gamma", float, 0),
        ("--max_depth", int, 8),
        ("--cpus_per_node", int, 4),
    ]
    for flag, flag_type, flag_default in arg_specs:
        parser.add_argument(flag, type=flag_type, default=flag_default)
    args = parser.parse_args()
    # hand parsed arguments to the training entry point
    main(args)