-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
22 lines (15 loc) · 834 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import argparse
from train import SeverityClassificationPipeline
# Location of the processed training data handed to the pipeline by main().
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory, not against this script's location — run from repo root.
file_path: str = "data/processed-data.csv"
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which preserves the original
            command-line behaviour; passing an explicit list lets callers and
            tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``test_size`` (float, default 0.2),
        ``random_state`` (int, default 42), ``lr`` (float or None) and
        ``max_depth`` (int or None).
    """
    parser = argparse.ArgumentParser()
    # argparse derives ``dest`` from the long option name ('--test_size' ->
    # 'test_size'), so spelling out ``dest=`` for each option is redundant.
    parser.add_argument('--test_size', help='Test Size (default 0.2)', type=float, default=0.2)
    parser.add_argument('--random_state', help='Random State', type=int, default=42)
    # A None default presumably means "let the pipeline choose" — TODO confirm
    # against SeverityClassificationPipeline.
    parser.add_argument('--lr', help='LEARNING RATE', type=float, default=None)
    parser.add_argument('--max_depth', help='Max Depth', type=int, default=None)
    return parser.parse_args(argv)
def main(args):
    """Build the severity-classification pipeline from parsed CLI options and run it.

    Args:
        args: Namespace produced by ``parse_args``; its ``test_size``,
            ``random_state``, ``lr`` and ``max_depth`` attributes are passed
            positionally, after the module-level ``file_path``, to the
            pipeline constructor.
    """
    pipeline = SeverityClassificationPipeline(
        file_path,
        args.test_size,
        args.random_state,
        args.lr,
        args.max_depth,
    )
    # Only the augmented-data path of the pipeline is exercised by this script.
    pipeline.run_augmented_data()
# Script entry point: parse options from the command line, then train.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)