From 7ed0c5f9fbec07c5a74c2f2f2396984938c1abd8 Mon Sep 17 00:00:00 2001 From: Bo Tang Date: Mon, 17 Jul 2023 02:40:59 -0400 Subject: [PATCH] Bug fix: model sense --- pkg/pyepo/func/blackbox.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/pyepo/func/blackbox.py b/pkg/pyepo/func/blackbox.py index b4862af1..49ac5f07 100644 --- a/pkg/pyepo/func/blackbox.py +++ b/pkg/pyepo/func/blackbox.py @@ -225,6 +225,8 @@ def forward(ctx, pred_cost, optmodel, processes, pool, solve_ratio, module): sol, _ = _cache_in_pass(cp, optmodel, module.solpool) # convert to tensor pred_sol = torch.FloatTensor(np.array(sol)).to(device) + # add other objects to ctx + ctx.optmodel = optmodel return pred_sol @staticmethod @@ -232,8 +234,13 @@ def backward(ctx, grad_output): """ Backward pass for NID """ + optmodel = ctx.optmodel # get device device = grad_output.device # identity matrix I = torch.eye(grad_output.shape[1]).to(device) - return grad_output @ (-I), None, None, None, None, None + if optmodel.modelSense == EPO.MINIMIZE: + grad = - I + if optmodel.modelSense == EPO.MAXIMIZE: + grad = I + return grad_output @ grad, None, None, None, None, None