## Differentiable Optimizers with Perturbations in PyTorch

This repository contains a PyTorch implementation of Differentiable Optimizers with Perturbations, which was originally implemented in TensorFlow. All credit belongs to the original authors, who are listed below. The source code, tests, and examples given below mirror the original work one-to-one, but are implemented in pure PyTorch.

## Overview

We propose in this work a universal method to transform any optimizer into a differentiable approximation. We provide a PyTorch implementation, illustrated here on some examples.

## Perturbed argmax

We start from an original optimizer, an `argmax` function, computed on an example input `theta`.

```
import torch
import torch.nn.functional as F
import perturbations

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def argmax(x, axis=-1):
    # One-hot encoding of the index of the largest entry along `axis`.
    return F.one_hot(torch.argmax(x, dim=axis), list(x.shape)[axis]).float()
```

This function returns a one-hot vector corresponding to the largest input entry.

```
>>> argmax(torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0]))
tensor([0., 1., 0., 0., 0.])
```

It is possible to modify the function by creating a perturbed optimizer, using Gumbel noise.

```
pert_argmax = perturbations.perturbed(argmax,
                                      num_samples=1000000,
                                      sigma=0.5,
                                      noise='gumbel',
                                      batched=False,
                                      device=device)
```

```
>>> theta = torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0], device=device)
>>> pert_argmax(theta)
tensor([0.0055, 0.8150, 0.0122, 0.1648, 0.0025], device='cuda:0')
```

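In this particular case, its expectation is exactly the usual softmax with exponential weights, `softmax(theta / sigma)`; the Monte Carlo estimate above agrees with it up to sampling error.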

```
>>> sigma = 0.5
>>> F.softmax(theta/sigma, dim=-1)
tensor([0.0055, 0.8152, 0.0122, 0.1646, 0.0025], device='cuda:0')
```
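To make this equivalence concrete, here is a minimal sketch of the Monte Carlo average that a Gumbel-perturbed argmax computes, written in plain PyTorch without the `perturbations` package. The helper names `one_hot_argmax` and `mc_perturbed_argmax`, the seeded generator, and the inverse-CDF Gumbel sampler are assumptions for illustration, not the package's internals.

```python
import torch
import torch.nn.functional as F


def one_hot_argmax(x):
    # One-hot encoding of the largest entry along the last dimension.
    return F.one_hot(torch.argmax(x, dim=-1), x.shape[-1]).float()


def mc_perturbed_argmax(theta, sigma=0.5, num_samples=100_000, seed=0):
    # Monte Carlo estimate of E[argmax(theta + sigma * Z)], Z ~ standard Gumbel.
    gen = torch.Generator(device=theta.device).manual_seed(seed)
    u = torch.rand((num_samples,) + tuple(theta.shape), generator=gen,
                   device=theta.device, dtype=theta.dtype)
    # Inverse-CDF sampling of Gumbel noise; the clamp avoids log(0).
    z = -torch.log(-torch.log(u.clamp(min=1e-20)))
    # Average the one-hot solutions over the leading sample dimension.
    return one_hot_argmax(theta + sigma * z).mean(dim=0)


theta = torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0])
print(mc_perturbed_argmax(theta))      # close to the softmax on the next line
print(F.softmax(theta / 0.5, dim=-1))
```

By the Gumbel-max trick, `argmax(theta + sigma * Z)` picks index `i` with probability `softmax(theta / sigma)[i]`, so the only discrepancy is sampling error, which shrinks like `1 / sqrt(num_samples)`. Because the noise is added along a new leading dimension, the same function also accepts a batch of inputs.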

### Batched version

The original function can accept a batch dimension, and is applied to every element of the batch.

```
theta_batch = torch.tensor([[-0.6, 1.9, -0.2, 1.1, -1.0],
                            [-0.6, 1.0, -0.2, 1.8, -1.0]], device=device, requires_grad=True)
```

```
>>> argmax(theta_batch)
tensor([[0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0.]], device='cuda:0')
```

Likewise, if the argument `batched` is set to `True` (its default value), the perturbed optimizer can handle a batch of inputs.

```
pert_argmax = perturbations.perturbed(argmax,
                                      num_samples=1000000,
                                      sigma=0.5,
                                      noise='gumbel',
                                      batched=True,
                                      device=device)
```

```
>>> pert_argmax(theta_batch)
tensor([[0.0055, 0.8158, 0.0122, 0.1640, 0.0025],
[0.0066, 0.1637, 0.0147, 0.8121, 0.0030]], device='cuda:0')
```

It can be compared to its deterministic version, the softmax.

```
>>> F.softmax(theta_batch/sigma, dim=-1)
tensor([[0.0055, 0.8152, 0.0122, 0.1646, 0.0025],
[0.0067, 0.1639, 0.0149, 0.8116, 0.0030]], device='cuda:0')
```

### Decorator version

It is also possible to use the perturbed function as a decorator.

```
@perturbations.perturbed(num_samples=1000000, sigma=0.5, noise='gumbel', batched=True, device=device)
def argmax(x, axis=-1):
    return F.one_hot(torch.argmax(x, dim=axis), list(x.shape)[axis]).float()
```

```
>>> argmax(theta_batch)
tensor([[0.0054, 0.8148, 0.0121, 0.1652, 0.0024],
[0.0067, 0.1639, 0.0148, 0.8116, 0.0029]], device='cuda:0')
```
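A function that works both as `perturbed(f, ...)` and as `@perturbed(...)` typically takes the wrapped function as an optional first argument. The sketch below shows that pattern with a hypothetical `perturbed_sketch` that only implements the forward Monte Carlo average (standard normal noise, no custom gradient); it is not the `perturbations` package's implementation.

```python
import functools

import torch
import torch.nn.functional as F


def perturbed_sketch(func=None, *, num_samples=1000, sigma=0.05, seed=0):
    # Hypothetical helper: forward pass only, standard normal noise assumed.
    def wrap(f):
        @functools.wraps(f)
        def wrapper(theta):
            gen = torch.Generator(device=theta.device).manual_seed(seed)
            noise = torch.randn((num_samples,) + tuple(theta.shape), generator=gen,
                                device=theta.device, dtype=theta.dtype)
            # f must accept a leading sample dimension and act on the last one.
            return f(theta + sigma * noise).mean(dim=0)
        return wrapper

    if func is None:
        # Used as @perturbed_sketch(...): return the decorator itself.
        return wrap
    # Used as perturbed_sketch(f, ...): wrap f right away.
    return wrap(func)


def argmax(x, axis=-1):
    return F.one_hot(torch.argmax(x, dim=axis), x.shape[axis]).float()


pert_argmax = perturbed_sketch(argmax, num_samples=5000, sigma=0.5)


@perturbed_sketch(num_samples=5000, sigma=0.5)
def decorated_argmax(x):
    return argmax(x)
```

With the same seed, `pert_argmax(theta)` and `decorated_argmax(theta)` return identical tensors. Making the options keyword-only (after `*`) keeps a positional number from being mistaken for the wrapped function.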

### Gradient computation

Perturbed optimizers are differentiable, and their gradients are computed automatically by stochastic estimation. In this case, the result can be compared directly to the gradient of the softmax.

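```
output = pert_argmax(theta_batch)
norm = torch.linalg.norm(output)
# torch.autograd.grad returns the gradient directly instead of accumulating
# it into theta_batch.grad, so later gradient computations are not affected.
grad_pert, = torch.autograd.grad(norm, theta_batch)
```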

```
>>> grad_pert
tensor([[-0.0072, 0.1708, -0.0132, -0.1476, -0.0033],
[-0.0068, -0.1464, -0.0173, 0.1748, -0.0046]], device='cuda:0')
```

For comparison, here is the same computation with a softmax.

```
output = F.softmax(theta_batch / sigma, dim=-1)
norm = torch.linalg.norm(output)
# Computed with torch.autograd.grad so nothing is accumulated in theta_batch.grad.
grad_soft, = torch.autograd.grad(norm, theta_batch)
```

```
>>> grad_soft
tensor([[-0.0064, 0.1714, -0.0142, -0.1479, -0.0029],
[-0.0077, -0.1457, -0.0170, 0.1739, -0.0035]], device='cuda:0')
```
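The stochastic gradient can be sketched as a custom `torch.autograd.Function`. For noise with density proportional to `exp(-nu(z))`, the Jacobian of the perturbed optimizer is `E[y(theta + sigma * Z) grad_nu(Z)^T] / sigma`, and for standard Gumbel noise `grad_nu(z) = 1 - exp(-z)`. The class name `PerturbedArgmaxSketch` and its argument order are hypothetical; the `perturbations` package may organize this differently.

```python
import torch
import torch.nn.functional as F


def one_hot_argmax(x):
    return F.one_hot(torch.argmax(x, dim=-1), x.shape[-1]).float()


class PerturbedArgmaxSketch(torch.autograd.Function):
    """Gumbel-perturbed argmax with a Monte Carlo vector-Jacobian product."""

    @staticmethod
    def forward(ctx, theta, sigma, num_samples, seed):
        gen = torch.Generator(device=theta.device).manual_seed(seed)
        u = torch.rand((num_samples,) + tuple(theta.shape), generator=gen,
                       device=theta.device, dtype=theta.dtype)
        z = -torch.log(-torch.log(u.clamp(min=1e-20)))  # standard Gumbel noise
        y = one_hot_argmax(theta + sigma * z)           # one solution per sample
        ctx.save_for_backward(y, z)
        ctx.sigma = sigma
        return y.mean(dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        y, z = ctx.saved_tensors
        # Gradient of the Gumbel negative log-density nu(z) = z + exp(-z).
        grad_nu = 1.0 - torch.exp(-z)
        # J^T g = E[<g, y> * grad_nu(Z)] / sigma, averaged over the samples.
        inner = (y * grad_output).sum(dim=-1, keepdim=True)
        grad_theta = (inner * grad_nu).mean(dim=0) / ctx.sigma
        # No gradients for sigma, num_samples, or seed.
        return grad_theta, None, None, None


theta = torch.tensor([[-0.6, 1.9, -0.2, 1.1, -1.0],
                      [-0.6, 1.0, -0.2, 1.8, -1.0]], requires_grad=True)
output = PerturbedArgmaxSketch.apply(theta, 0.5, 200_000, 0)
norm = torch.linalg.norm(output)
grad_mc, = torch.autograd.grad(norm, theta)
```

Only the one-hot solutions and the noise are needed in the backward pass, so the solver itself is never differentiated; this is what allows any blackbox optimizer to be wrapped.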

## Perturbed OR

The OR function over the signs of the inputs, which is an example of an optimizer, offers an easily interpretable visualization.

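```
def hard_or(x):
    # True where an entry is non-negative (sign 0 maps to 0.5, which casts to True).
    s = ((torch.sign(x) + 1) / 2.0).type(torch.bool)
    result = torch.any(s, dim=-1)
    # Map {False, True} to {-1, 1}.
    return result.type(torch.float) * 2.0 - 1
```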

In the following batch of two inputs, both instances are evaluated as `True` (value `1`).

```
theta = torch.tensor([[-5., 0.2],
                      [-5., 0.1]], device=device)
```

```
>>> hard_or(theta)
tensor([1., 1.])
```

Computing a perturbed OR operator over 1000 samples shows the difference in value for these two inputs.

```
pert_or = perturbations.perturbed(hard_or,
                                  num_samples=1000,
                                  sigma=0.1,
                                  noise='gumbel',
                                  batched=True,
                                  device=device)
```

```
>>> pert_or(theta)
tensor([1.0000, 0.8540], device='cuda:0')
```

This can be visualized more broadly for values between -1 and 1, together with the evaluated values of the gradient.
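The sweep behind such a plot can be sketched as follows, assuming standard Gumbel noise (location 0, scale 1) and fixing the first coordinate at -5 as in the example above. Because the perturbed OR equals -1 only when every perturbed entry is negative, its value has the closed form `1 - 2 * prod_i exp(-exp(theta_i / sigma))`, which gives `1 - 2 * exp(-e) ≈ 0.868` for `theta = [-5, 0.1]` and `sigma = 0.1`; the 1000-sample estimate above fluctuates around this value. The names `mc_perturbed_or` and `closed_form_or` are hypothetical.

```python
import torch


def hard_or(x):
    s = ((torch.sign(x) + 1) / 2.0).type(torch.bool)
    return torch.any(s, dim=-1).type(torch.float) * 2.0 - 1


def mc_perturbed_or(theta, sigma=0.1, num_samples=100_000, seed=0):
    # Monte Carlo average of hard_or over Gumbel-perturbed inputs (CPU only).
    gen = torch.Generator().manual_seed(seed)
    u = torch.rand((num_samples,) + tuple(theta.shape), generator=gen)
    z = -torch.log(-torch.log(u.clamp(min=1e-20)))
    return hard_or(theta + sigma * z).mean(dim=0)


def closed_form_or(theta, sigma=0.1):
    # P(theta_i + sigma * Z_i < 0) = exp(-exp(theta_i / sigma)) for standard Gumbel Z_i.
    p_all_negative = torch.exp(-torch.exp(theta / sigma)).prod(dim=-1)
    return 1.0 - 2.0 * p_all_negative


# Sweep the second coordinate over [-1, 1] with the first one fixed at -5.
x2 = torch.linspace(-1.0, 1.0, 21)
theta = torch.stack([torch.full_like(x2, -5.0), x2], dim=-1)  # shape (21, 2)
values_mc = mc_perturbed_or(theta)

# Differentiating the closed form gives the exact gradient along the sweep.
theta_req = theta.clone().requires_grad_(True)
closed_form_or(theta_req).sum().backward()
grad_x2 = theta_req.grad[:, 1]
```

The gradient along `x2` is concentrated in a band of width roughly `sigma` around 0, which is where the hard OR jumps from -1 to 1.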

## Perturbed shortest path

This framework can also be easily applied to more complex optimizers, such as a blackbox shortest path solver (here the function `shortest_path`). We consider a small example on 9 nodes, illustrated here with the shortest path between nodes 0 and 8 in bold and edge costs as labels.

We also consider a function of the perturbed solution: the weight of this solution on the edge between nodes **6** and **8**.

A gradient of this function with respect to a vector of four edge costs (the top-rightmost edges, between nodes 4, 5, 6, and 8) is computed automatically. It can be used to increase the weight of this edge in the solution by changing these four costs. This is challenging to do with first-order methods using only the original optimizer, since its gradient is zero almost everywhere.
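This problem is easy to check on the `argmax` example from the beginning: nudging any entry of `theta` by a small amount leaves the one-hot output unchanged, so every finite-difference derivative is exactly zero. The step size `eps = 1e-3` is an arbitrary choice for illustration.

```python
import torch
import torch.nn.functional as F


def argmax(x, axis=-1):
    return F.one_hot(torch.argmax(x, dim=axis), x.shape[axis]).float()


theta = torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0])
eps = 1e-3
base = argmax(theta)

# Finite-difference Jacobian: column j is (argmax(theta + eps * e_j) - argmax(theta)) / eps.
jacobian = torch.stack(
    [(argmax(theta + eps * torch.eye(5)[j]) - base) / eps for j in range(5)],
    dim=1,
)
print(jacobian.abs().max())  # tensor(0.): no first-order signal to follow
```

Autograd gives no help either: `torch.argmax` returns integer indices, so the one-hot output is not connected to `theta` in the computation graph.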

```
# `edge_costs_to_weights` and `shortest_path` are not defined in this README;
# they are assumed to come from the accompanying example code.
final_edges_costs = torch.tensor([0.4, 0.1, 0.1, 0.1], device=device, requires_grad=True)
weights = edge_costs_to_weights(final_edges_costs)

@perturbations.perturbed(num_samples=100000, sigma=0.05, batched=False, device=device)
def perturbed_shortest_path(weights):
    return shortest_path(weights, symmetric=False)
```

We obtain a perturbed solution to the shortest path problem on this graph, an average of solutions under perturbations on the weights.

```
>>> perturbed_shortest_path(weights)
tensor([[0. 0. 0.001 0.025 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0.023 0. 0. 0. 0. ]
[0.679 0. 0. 0.119 0. 0. 0. 0. 0. ]
[0.304 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0.023 0. 0. 0. 0.898 0. 0. 0. ]
[0. 0. 0.001 0. 0. 0. 0.896 0. 0. ]
[0. 0. 0. 0. 0. 0.001 0. 0.974 0. ]
[0. 0. 0.797 0.178 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0.921 0. 0.079 0. 0. ]])
```

For illustration, this solution can be represented with edge widths proportional to the weight of the solution.

As an example of a scalar function of this solution, we take the weight of the perturbed solution on the edge from node 6 to node 8 (current value `0.079`).

```
def i_to_j_weight_fn(i, j, paths):
    return paths[..., i, j]

weights = edge_costs_to_weights(final_edges_costs)
pert_paths = perturbed_shortest_path(weights)
# Entry [8, 6] of the solution matrix is the edge from node 6 to node 8 (0.079 above).
i_to_j_weight = i_to_j_weight_fn(8, 6, pert_paths)
i_to_j_weight.backward()
grad = final_edges_costs.grad
```

This gradient, obtained thanks to our perturbed version of the optimizer, provides a direction in which to modify the vector of four edge costs to increase the weight of this edge in the solution.

```
>>> grad
tensor([-2.0993764, 2.076386 , 2.042395 , 2.0411625], device='cuda:0')
```

Running gradient *ascent* for 30 steps on this vector of four edge costs to *increase* the weight of the edge from 6 to 8 modifies the problem. Its new perturbed solution has a corresponding edge weight of `0.989`. The new problem and its perturbed solution can be visualized in the same way.
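The ascent loop itself is not shown here. As a hedged, self-contained sketch of the same pattern, the loop below raises the weight that a Gumbel-perturbed argmax puts on one coordinate, using a Monte Carlo gradient estimator in place of the shortest path solver. The class `PerturbedArgmax`, the step size `lr = 0.5`, the per-step sample count, and the target index are illustrative assumptions; in the real example, `theta` would be `final_edges_costs` and the objective the 6 to 8 edge weight.

```python
import torch
import torch.nn.functional as F


def one_hot_argmax(x):
    return F.one_hot(torch.argmax(x, dim=-1), x.shape[-1]).float()


class PerturbedArgmax(torch.autograd.Function):
    # Gumbel-perturbed argmax (CPU only); backward uses the Monte Carlo estimator
    # J^T g = E[<g, y(theta + sigma * Z)> * (1 - exp(-Z))] / sigma.
    @staticmethod
    def forward(ctx, theta, sigma, num_samples, seed):
        gen = torch.Generator().manual_seed(seed)
        u = torch.rand((num_samples,) + tuple(theta.shape), generator=gen)
        z = -torch.log(-torch.log(u.clamp(min=1e-20)))
        y = one_hot_argmax(theta + sigma * z)
        ctx.save_for_backward(y, z)
        ctx.sigma = sigma
        return y.mean(dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        y, z = ctx.saved_tensors
        inner = (y * grad_output).sum(dim=-1, keepdim=True)
        grad_theta = (inner * (1.0 - torch.exp(-z))).mean(dim=0) / ctx.sigma
        return grad_theta, None, None, None


sigma, lr, target = 0.5, 0.5, 3
theta = torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0], requires_grad=True)

for step in range(30):
    # Objective: weight of the perturbed solution on the target coordinate.
    weight = PerturbedArgmax.apply(theta, sigma, 20_000, step)[target]
    grad, = torch.autograd.grad(weight, theta)
    with torch.no_grad():
        theta += lr * grad  # ascent: move theta along the gradient

final_weight = PerturbedArgmax.apply(theta.detach(), sigma, 100_000, 1234)[target]
print(final_weight)
```

A fresh seed at every step gives independent noise, so the updates follow an unbiased stochastic estimate of the gradient of the smoothed objective.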

## References

Berthet Q., Blondel M., Teboul O., Cuturi M., Vert J.-P., Bach F. Learning with Differentiable Perturbed Optimizers. NeurIPS 2020.

## License

Please see the original repository for proper details.