Skip to content

Commit

Permalink
Add output name as an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Oct 28, 2024
1 parent 5587e0f commit c70b60b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/AMSlib/generate_tupple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def main(args):
device = args[3]
uq_type = args[4]
precision = args[5]
output_name = args[6]
enable_cuda = True
if device == "cuda":
enable_cuda = True
Expand Down Expand Up @@ -82,11 +83,11 @@ def main(args):

with torch.jit.optimized_execution(True):
traced = torch.jit.trace(model, (torch.randn(inputDim, dtype=prec).to(device),))
traced.save(f"uq_{uq_type}_{precision}{suffix}.pt")
traced.save(f"{output_name}")

data = torch.zeros(2, inputDim, dtype=prec)
inputs = Variable(data.to(device))
model = jit.load(f"uq_{uq_type}_{precision}{suffix}.pt")
model = jit.load(f"{output_name}")
model.eval()
with torch.no_grad():
print("Ouput", model(inputs))
Expand Down

0 comments on commit c70b60b

Please sign in to comment.