-
Notifications
You must be signed in to change notification settings - Fork 36
/
calculate_nlls.py
executable file
·60 lines (44 loc) · 2.16 KB
/
calculate_nlls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python
# coding=utf-8
"""
Calculates the NLLs of a set of scaffolds and decorations.
"""
import argparse
import models.model as mm
import models.actions as ma
import utils.log as ul
import utils.chem as uc
def parse_args(argv=None):
    """Parses input arguments.

    :param argv: Optional list of argument strings to parse. When None (the default),
        argparse reads ``sys.argv[1:]``, exactly as before; passing a list allows
        programmatic invocation and testing.
    :return: An ``argparse.Namespace`` with the attributes ``input_csv_path``,
        ``output_csv_path``, ``model_path``, ``batch_size`` and ``use_gzip``.
    """
    parser = argparse.ArgumentParser(
        description="Calculates NLLs of a list of scaffolds and decorations given a model.")
    parser.add_argument("--input-csv-path", "-i",
                        help="Path to the input CSV file. The first and second fields are the scaffold "
                             "and the decoration. The rest are going to be kept as-is.",
                        type=str, required=True)
    parser.add_argument("--output-csv-path", "-o",
                        help="Path to the output CSV file which will have the NLL added as a new field in the end.",
                        type=str, required=True)
    parser.add_argument("--model-path", "-m", help="Path to the model that will be used.", type=str, required=True)
    parser.add_argument("--batch-size", "-b",
                        help="Batch size used to calculate NLLs (DEFAULT: 128).", type=int, default=128)
    # store_true already defaults to False, so no explicit default is needed.
    parser.add_argument("--use-gzip", help="Compress the output file (if set).", action="store_true")
    return parser.parse_args(argv)
def main():
    """Main function.

    Loads the model, computes the NLL of every (scaffold, decoration) pair taken from
    the first two fields of the input CSV, and writes each original input line to the
    output file followed by a tab and the NLL formatted with 8 decimals.
    If ``--use-gzip`` is set, ".gz" is appended to the output path before opening it.
    Both files are closed even if an error occurs part-way through.
    """
    args = parse_args()

    model = mm.DecoratorModel.load_from_file(args.model_path, mode="sampling")

    input_csv = uc.open_file(args.input_csv_path, mode="rt")
    try:
        if args.use_gzip:
            # NOTE(review): presumably uc.open_file chooses gzip from the ".gz" suffix — confirm.
            args.output_csv_path += ".gz"
        output_csv = uc.open_file(args.output_csv_path, mode="wt+")
        try:
            calc_nlls_action = ma.CalculateNLLsFromModel(model, batch_size=args.batch_size, logger=LOG)
            # NOTE(review): the input is read twice (parsed rows here, raw lines below);
            # the pairing assumes uc.read_csv_file yields exactly one row per physical line.
            scaffold_decoration_list = [fields[0:2] for fields in uc.read_csv_file(args.input_csv_path)]
            nlls = calc_nlls_action.run(scaffold_decoration_list)
            for nll in ul.progress_bar(nlls, total=len(scaffold_decoration_list)):
                # Remove only the line terminator: strip() would also eat tab separators of
                # empty leading/trailing fields, shifting the appended NLL into the wrong column.
                input_line = input_csv.readline().rstrip("\r\n")
                output_csv.write("{}\t{:.8f}\n".format(input_line, nll))
        finally:
            output_csv.close()
    finally:
        input_csv.close()
# Module-level logger; main() hands it to ma.CalculateNLLsFromModel via logger=LOG.
# It is bound at import time, so it exists before main() runs below.
LOG = ul.get_logger("calculate_nlls")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()