From 5edf6ecb55acf3cb916a8fb17861fd5c4bf1d58a Mon Sep 17 00:00:00 2001 From: lewismervin1 <58475084+lewismervin1@users.noreply.github.com> Date: Wed, 3 Jul 2024 08:55:47 +0100 Subject: [PATCH] Update optbuild.py remove AZ specific inference function --- optunaz/optbuild.py | 50 --------------------------------------------- 1 file changed, 50 deletions(-) diff --git a/optunaz/optbuild.py b/optunaz/optbuild.py index 3ef0ad7..eec9634 100644 --- a/optunaz/optbuild.py +++ b/optunaz/optbuild.py @@ -22,44 +22,6 @@ logger = logging.getLogger(__name__) -def predict_pls(model_path, inference_path): - if inference_path == "None": - logger.info(f"Inference path is not set so AL predictions not performed") - return - else: - logger.info(f"Inference path is {inference_path}") - predict_args = [ - "prog", - "--model-file", - str(model_path), - "--input-smiles-csv-file", - str(inference_path), - "--input-smiles-csv-column", - "Structure", - "--output-prediction-csv-file", - str(os.path.dirname(model_path)) + "/al.csv", - "--predict-uncertainty", - "--uncertainty_quantile", - "0.99", - ] - try: - with patch.object(sys, "argv", predict_args): - logging.info("Performing active learning predictions") - predict.main() - except FileNotFoundError: - logger.info( - f"PLS file not found at {model_path}, AL predictions not performed" - ) - except predict.UncertaintyError: - logging.info( - "PLS prediction not performed: algorithm does not support uncertainty prediction" - ) - except predict.AuxCovariateMissing: - logging.info( - "PLS prediction not performed: algorithm requires corvariate auxiliary data for inference" - ) - - def main(): logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser( @@ -97,12 +59,6 @@ def main(): help="Turn off descriptor generation caching ", action="store_true", ) - parser.add_argument( - "--inference_uncert", - help="Path for uncertainty inference and thresholding.", - type=pathlib.Path, - default="/projects/db-mirror/MLDatasets/PLS/pls.csv", - ) args = parser.parse_args() AnyConfig = Union[OptimizationConfig, BuildConfig] @@ -111,7 +67,6 @@ def main(): if isinstance(config, OptimizationConfig): study_name = str(pathlib.Path(args.config).absolute()) - pred_pls = False if not args.no_cache: config.set_cache() cache = config._cache @@ -123,7 +78,6 @@ def main(): if args.best_model_outpath or args.merged_model_outpath: buildconfig = buildconfig_best(study) elif isinstance(config, BuildConfig): - pred_pls = True buildconfig = config cache = None cache_dir = None @@ -140,15 +94,11 @@ def main(): args.best_model_outpath, cache=cache, ) - if not args.merged_model_outpath and pred_pls: - predict_pls(args.best_model_outpath, args.inference_uncert) if args.merged_model_outpath: build_merged( buildconfig, args.merged_model_outpath, cache=cache, ) - if pred_pls: - predict_pls(args.merged_model_outpath, args.inference_uncert) if cache_dir is not None: cache_dir.cleanup()