Skip to content

Commit

Permalink
New method: I-MLE
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasBoTang committed Oct 21, 2023
1 parent 5b1c54f commit 2ccb73b
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ To reproduce the experiments in original paper, please use the code and follow t

## Features

- Implement **SPO+** [[1]](https://doi.org/10.1287/mnsc.2020.3922), **DBB** [[3]](https://arxiv.org/abs/1912.02175), **NID** [[7]](https://arxiv.org/abs/2205.15213), **DPO** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **PFYL** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **NCE** [[5]](https://www.ijcai.org/proceedings/2021/390) and **LTR** [[6]](https://proceedings.mlr.press/v162/mandi22a.htm).
- Implement **SPO+** [[1]](https://doi.org/10.1287/mnsc.2020.3922), **DBB** [[3]](https://arxiv.org/abs/1912.02175), **I-MLE** [[8]](https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html), **NID** [[7]](https://arxiv.org/abs/2205.15213), **DPO** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **PFYL** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **NCE** [[5]](https://www.ijcai.org/proceedings/2021/390) and **LTR** [[6]](https://proceedings.mlr.press/v162/mandi22a.htm).
- Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), and [Pyomo](http://www.pyomo.org/) API
- Support Parallel computing for optimization solver
- Support solution caching [[5]](https://www.ijcai.org/proceedings/2021/390) to speed up training
Expand Down Expand Up @@ -183,3 +183,4 @@ if __name__ == "__main__":
* [5] [Mulamba, M., Mandi, J., Diligenti, M., Lombardi, M., Bucarey, V., & Guns, T. (2021). Contrastive losses and solution caching for predict-and-optimize. Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence.](https://www.ijcai.org/proceedings/2021/390)
* [6] [Mandi, J., Bucarey, V., Mulamba, M., & Guns, T. (2022). Decision-focused learning: through the lens of learning to rank. Proceedings of the 39th International Conference on Machine Learning.](https://proceedings.mlr.press/v162/mandi22a.html)
* [7] [Sahoo, S. S., Paulus, A., Vlastelica, M., Musil, V., Kuleshov, V., & Martius, G. (2022). Backpropagation through combinatorial algorithms: Identity with projection works. arXiv preprint arXiv:2205.15213.](https://arxiv.org/abs/2205.15213)
* [8] [Niepert, M., Minervini, P., & Franceschi, L. (2021). Implicit MLE: backpropagating through discrete exponential family distributions. Advances in Neural Information Processing Systems, 34, 14567-14579.](https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html)
2 changes: 1 addition & 1 deletion pkg/pyepo/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

from pyepo.func.spoplus import SPOPlus
from pyepo.func.blackbox import blackboxOpt, negativeIdentity
from pyepo.func.perturbed import perturbedOpt, perturbedFenchelYoung
from pyepo.func.perturbed import perturbedOpt, perturbedFenchelYoung, implicitMLE
from pyepo.func.contrastive import NCE, contrastiveMAP
from pyepo.func.rank import listwiseLTR, pairwiseLTR, pointwiseLTR
162 changes: 162 additions & 0 deletions pkg/pyepo/func/perturbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyepo import EPO
from pyepo.func.abcmodule import optModule
from pyepo.utlis import getArgs
from pyepo.func.utlis import sumGammaDistribution


class perturbedOpt(optModule):
Expand Down Expand Up @@ -263,6 +264,167 @@ def backward(ctx, grad_output):
return grad * grad_output, None, None, None, None, None, None, None, None, None


class implicitMLE(optModule):
"""
An autograd module for Implicit Maximum Likelihood Estimator, which yield
an optimal solution in a constrained exponential family distribution via
Perturb-and-MAP.
For I-LME, it works as black-box combinatorial solvers, in which constraints
are known and fixed, but the cost vector need to be predicted from
contextual data.
The I-LME approximate gradient of optimizer smoothly. Thus, allows us to
design an algorithm based on stochastic gradient descent.
Reference: <https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html>
"""

def __init__(self, optmodel, n_samples=10, sigma=1.0, lambd=10, processes=1,
distribution=sumGammaDistribution(kappa=5), solve_ratio=1,
dataset=None):
"""
Args:
optmodel (optModel): an PyEPO optimization model
n_samples (int): number of Monte-Carlo samples
sigma (float): noise temperature for the input distribution
lambd (float): a hyperparameter for differentiable block-box to control interpolation degree
processes (int): number of processors, 1 for single-core, 0 for all of cores
distribution (distribution): noise distribution
solve_ratio (float): the ratio of new solutions computed during training
dataset (None/optDataset): the training data
"""
super().__init__(optmodel, processes, solve_ratio, dataset)
# number of samples
self.n_samples = n_samples
# noise temperature
self.sigma = sigma
# smoothing parameter
if lambd <= 0:
raise ValueError("lambda is not positive.")
self.lambd = lambd
# noise distribution
self.distribution = distribution
# build I-LME
self.ilme = implicitMLEFunc()

def forward(self, pred_cost):
"""
Forward pass
"""
sols = self.ilme.apply(pred_cost, self.optmodel, self.n_samples,
self.sigma, self.lambd, self.processes,
self.pool, self.distribution, self.solve_ratio,
self)
return sols


class implicitMLEFunc(Function):
"""
A autograd function for Implicit Maximum Likelihood Estimator
"""

@staticmethod
def forward(ctx, pred_cost, optmodel, n_samples, sigma, lambd,
processes, pool, distribution, solve_ratio, module):
"""
Forward pass for IMLE
Args:
pred_cost (torch.tensor): a batch of predicted values of the cost
optmodel (optModel): an PyEPO optimization model
n_samples (int): number of Monte-Carlo samples
sigma (float): noise temperature for the input distribution
lambd (float): a hyperparameter for differentiable block-box to control interpolation degree
processes (int): number of processors, 1 for single-core, 0 for all of cores
pool (ProcessPool): process pool object
distribution (distribution): noise distribution
solve_ratio (float): the ratio of new solutions computed during training
module (optModule): implicitMLE module
Returns:
torch.tensor: predicted solutions
"""
# get device
device = pred_cost.device
# convert tenstor
cp = pred_cost.detach().to("cpu").numpy()
# sample perturbations
noises = distribution.sample(size=(n_samples, *cp.shape))
ptb_c = cp + sigma * noises
# solve with perturbation
rand_sigma = np.random.uniform()
if rand_sigma <= solve_ratio:
ptb_sols = _solve_in_pass(ptb_c, optmodel, processes, pool)
if solve_ratio < 1:
sols = ptb_sols.reshape(-1, cp.shape[1])
# add into solpool
module.solpool = np.concatenate((module.solpool, sols))
# remove duplicate
module.solpool = np.unique(module.solpool, axis=0)
else:
ptb_sols = _cache_in_pass(ptb_c, optmodel, module.solpool)
# solution average
e_sol = ptb_sols.mean(axis=1)
# convert to tensor
e_sol = torch.FloatTensor(e_sol).to(device)
# save
ctx.save_for_backward(pred_cost)
# add other objects to ctx
ctx.noises = noises
ctx.ptb_sols = ptb_sols
ctx.lambd = lambd
ctx.optmodel = optmodel
ctx.processes = processes
ctx.pool = pool
ctx.solve_ratio = solve_ratio
if solve_ratio < 1:
ctx.module = module
ctx.rand_sigma = rand_sigma
return e_sol

@staticmethod
def backward(ctx, grad_output):
"""
Backward pass for IMLE
"""
pred_cost, = ctx.saved_tensors
noises = ctx.noises
ptb_sols = ctx.ptb_sols
lambd = ctx.lambd
optmodel = ctx.optmodel
processes = ctx.processes
pool = ctx.pool
solve_ratio = ctx.solve_ratio
rand_sigma = ctx.rand_sigma
if solve_ratio < 1:
module = ctx.module
# get device
device = pred_cost.device
# convert tenstor
cp = pred_cost.detach().to("cpu").numpy()
dl = grad_output.detach().to("cpu").numpy()
# perturbed costs
ptb_cq = cp + lambd * dl + noises
# solve with perturbation
rand_sigma = np.random.uniform()
if rand_sigma <= solve_ratio:
ptb_solsq = _solve_in_pass(ptb_cq, optmodel, processes, pool)
if solve_ratio < 1:
sols = ptb_solsq.reshape(-1, cp.shape[1])
# add into solpool
module.solpool = np.concatenate((module.solpool, sols))
# remove duplicate
module.solpool = np.unique(module.solpool, axis=0)
else:
ptb_solsq = _cache_in_pass(ptb_cq, optmodel, module.solpool)
# get gradient
grad = (np.array(ptb_solsq) - ptb_sols).mean(axis=1) / lambd
# convert to tensor
grad = torch.FloatTensor(grad).to(device)
return grad, None, None, None, None, None, None, None, None, None


def _solve_in_pass(ptb_c, optmodel, processes, pool):
"""
A function to solve optimization in the forward pass
Expand Down
20 changes: 20 additions & 0 deletions pkg/pyepo/func/utlis.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,23 @@ def _check_sol(c, w, z):
raise AssertionError(
"Solution {} does not macth the objective value {}.".
format(c[i] @ w[i], z[i][0]))


class sumGammaDistribution:
"""
creates a generator of samples for the Sum-of-Gamma distribution
"""
def __init__(self, kappa, n_iterations=10, seed=135):
self.κ = kappa
self.n_iterations = n_iterations
self.rnd = np.random.RandomState(seed)

def sample(self, size):
# init samples
samples = 0
# calculate samples
for i in range(1, self.n_iterations+1):
samples += self.rnd.gamma(1/self.κ, self.κ/i, size)
samples -= np.log(self.n_iterations)
samples /= self.κ
return samples
2 changes: 1 addition & 1 deletion pkg/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# description
description = "PyTorch-based End-to-End Predict-then-Optimize Tool",
# version
version = "0.3.4",
version = "0.3.5",
# Github repo
url = "https://github.com/khalil-research/PyEPO",
# author name
Expand Down

0 comments on commit 2ccb73b

Please sign in to comment.