diff --git a/tests/AMSlib/generate_tupple_model.py b/tests/AMSlib/generate_tupple_model.py index f4cb232..104a353 100644 --- a/tests/AMSlib/generate_tupple_model.py +++ b/tests/AMSlib/generate_tupple_model.py @@ -43,6 +43,7 @@ def main(args): device = args[3] uq_type = args[4] precision = args[5] + output_name = args[6] enable_cuda = True if device == "cuda": enable_cuda = True @@ -82,11 +83,11 @@ def main(args): with torch.jit.optimized_execution(True): traced = torch.jit.trace(model, (torch.randn(inputDim, dtype=prec).to(device),)) - traced.save(f"uq_{uq_type}_{precision}{suffix}.pt") + traced.save(f"{output_name}") data = torch.zeros(2, inputDim, dtype=prec) inputs = Variable(data.to(device)) - model = jit.load(f"uq_{uq_type}_{precision}{suffix}.pt") + model = jit.load(f"{output_name}") model.eval() with torch.no_grad(): print("Ouput", model(inputs))