Skip to content

Commit

Permalink
Bug fix: model sense
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasBoTang committed Jul 17, 2023
1 parent 5fb8731 commit 7ed0c5f
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pkg/pyepo/func/blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,22 @@ def forward(ctx, pred_cost, optmodel, processes, pool, solve_ratio, module):
sol, _ = _cache_in_pass(cp, optmodel, module.solpool)
# convert to tensor
pred_sol = torch.FloatTensor(np.array(sol)).to(device)
# add other objects to ctx
ctx.optmodel = optmodel
return pred_sol

@staticmethod
def backward(ctx, grad_output):
"""
Backward pass for NID
"""
optmodel = ctx.optmodel
# get device
device = grad_output.device
# identity matrix
I = torch.eye(grad_output.shape[1]).to(device)
return grad_output @ (-I), None, None, None, None, None
if optmodel.modelSense == EPO.MINIMIZE:
grad = - I
if optmodel.modelSense == EPO.MAXIMIZE:
grad = I
return grad_output @ grad, None, None, None, None, None

0 comments on commit 7ed0c5f

Please sign in to comment.