Simple Barycenter Computation

[1]:
from mmot import MMOTSolver

import numpy as np
import matplotlib.pyplot as plt
import itertools
[2]:
# Grid of n1 x n2 cells covering the unit square; points sit at cell centres.
n1 = 128  # number of cells along the x axis
n2 = 128  # number of cells along the y axis

# np.meshgrid returns arrays of shape (n2, n1): rows index y, columns index x.
# Each axis uses its own cell count for the half-cell offset at both ends.
x, y = np.meshgrid(np.linspace(0.5/n1, 1-0.5/n1, n1),
                   np.linspace(0.5/n2, 1-0.5/n2, n2))

r = 0.125  # disk radius; only referenced by experimental (disk-shaped) variants

# Centres of the input measures.
positions = np.array([[0.25, 0.25],
                      [0.75, 0.75],
                      [0.25, 0.75]])

measures = [None]*len(positions)
for i, (xc, yc) in enumerate(positions):
    # Isotropic Gaussian bump: exp(-|p - c|^2 / (2 * 0.01)), i.e. std 0.1.
    density = np.exp(-0.5*((x-xc)**2 + (y-yc)**2)*100.0)
    # Zero out the far tails so each measure has compact support.
    density[density < 1e-4] = 0.0
    # Normalise so the values sum to n1*n2 (average value 1 per cell).
    measures[i] = density * (n1*n2 / np.sum(density))

fig, ax = plt.subplots(1, len(measures), figsize=(12, 4))
for i, mu in enumerate(measures):
    ax[i].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
    ax[i].set_title("$\\mu_{{ {:0d} }}$".format(i))
../_images/examples_Barycenter_2_0.png
[3]:
# Pairwise-cost graph (the set A): every pair of the three measures is coupled.
edge_list = [[0, 1], [1, 2], [0, 2]]

# Equal barycentric weights, normalised so they sum to one.
bary_weights = np.ones(3)
bary_weights = bary_weights / bary_weights.sum()

# Passed straight through to MMOTSolver as its `unroll_node` argument
# (NOTE(review): presumably the node the cost graph is unrolled around -- confirm in mmot docs).
unroll_node = 0

prob = MMOTSolver(measures, edge_list, x, y, unroll_node, bary_weights)

# Render the cost graph to an SVG file.
prob.Visualize('original', filename='CostGraph.svg')
../_images/examples_Barycenter_3_0.png
[4]:
# One grid-shaped array per dual variable, all starting at zero.
dual_vars = [np.zeros(x.shape) for _ in range(prob.NumDual())]

num_its = 200     # maximum number of iterations
step_size = 0.25  # initial step size; adapted each iteration by prob.StepSizeUpdate
costs = np.full(num_its, np.nan)  # NaN marks iterations that never ran

ftol = 0.0  # stop when |cost change| < ftol (0 disables the check)
gtol = 0.0  # stop when the error returned by prob.Step < gtol (0 disables the check)

# Cycle the root passed to prob.Step through every dual index in turn.
root_nodes = np.arange(prob.NumDual())
root_cycler = itertools.cycle(root_nodes)

n_done = 0  # number of iterations actually completed (for plotting)
print('Iteration, StepSize,    Cost,   Error')
for i in range(num_its):
    error = prob.Step(next(root_cycler), dual_vars, step_size)

    costs[i] = prob.ComputeCost(dual_vars)
    n_done = i + 1

    if i > 0:
        step_size = prob.StepSizeUpdate(step_size, costs[i], costs[i-1], error)

    if i % 10 == 0:
        print('{:9d},   {:0.4f},  {:0.4f},   {:0.5f}'.format(i, step_size, costs[i], error))

    # Convergence in cost needs a previous cost; at i == 0, costs[i-1] would
    # silently wrap around to costs[-1].
    if i > 0 and np.abs(costs[i] - costs[i-1]) < ftol:
        break

    # Convergence via the error reported by prob.Step.
    if error < gtol:
        break

cost_fig, cost_ax = plt.subplots()
cost_ax.plot(np.arange(n_done), costs[:n_done])
cost_ax.set_xlabel('Iteration')
cost_ax.set_ylabel('Cost')
cost_ax.set_title('Cost per iteration')
plt.show()
Iteration, StepSize,    Cost,   Error
        0,   0.2500,  0.0057,   1.99919
       10,   0.0593,  0.0368,   0.81437
       20,   0.0188,  0.0531,   0.26774
       30,   0.0106,  0.0547,   0.00155
       40,   0.0079,  0.0547,   0.00048
       50,   0.0106,  0.0547,   0.00030
       60,   0.0106,  0.0548,   0.00014
       70,   0.0079,  0.0548,   0.00008
       80,   0.0106,  0.0548,   0.00010
       90,   0.0106,  0.0548,   0.00004
      100,   0.0059,  0.0548,   0.00014
      110,   0.0059,  0.0548,   0.00003
      120,   0.0059,  0.0548,   0.00003
      130,   0.0059,  0.0548,   0.00002
      140,   0.0059,  0.0548,   0.00003
      150,   0.0059,  0.0548,   0.00002
      160,   0.0059,  0.0548,   0.00002
      170,   0.0059,  0.0548,   0.00001
      180,   0.0059,  0.0548,   0.00002
      190,   0.0059,  0.0548,   0.00001
../_images/examples_Barycenter_4_1.png
[5]:
# Uniform weights over the input measures. Note: these are not passed to
# prob.Barycenter below; the solver was constructed with `bary_weights`.
weights = np.full(len(measures), 1.0/len(measures))

# Recover the barycenter from the optimised dual variables.
bary = prob.Barycenter(dual_vars)

[6]:
# Show the inputs, the estimated barycenter and a reference barycenter side by side.
fig, axs = plt.subplots(1, len(measures)+2, figsize=(12, 4))
for i, mu in enumerate(measures):
    axs[i].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
    axs[i].set_title("$\\mu_{{ {:0d} }}$".format(i))

axs[-2].imshow(bary, origin='lower', extent=(0, 1, 0, 1))
axs[-2].set_title('Estimated Barycenter')

# Reference solution: the inputs are translated copies of one Gaussian, so
# (assuming squared-Euclidean pairwise costs -- TODO confirm against MMOTSolver)
# the barycenter is that same Gaussian centred at the bary_weights-weighted
# mean of the input centres. The inputs' 1e-4 tail truncation is not applied
# here, so this is an approximation.
true_loc = np.average(positions, axis=0, weights=bary_weights)
true_bary = np.exp(-0.5*((x-true_loc[0])**2 + (y-true_loc[1])**2)*100.0)
true_bary *= n1*n2 / np.sum(true_bary)

axs[-1].imshow(true_bary, origin='lower', extent=(0, 1, 0, 1))
axs[-1].set_title('True Barycenter')
plt.show()

[6]:
Text(0.5, 1.0, 'True Barycenter')
../_images/examples_Barycenter_6_1.png
[7]:
# Horizontal slice through both barycenters at (approximately) y = 0.6.
# Rows of the (n2, n1) grids index y, so the row index must scale with n2.
row = int(0.6*n2)

slice_fig, slice_ax = plt.subplots()
# Plot against the physical x coordinate rather than the column index.
slice_ax.plot(x[row, :], true_bary[row, :], label='True Barycenter')
slice_ax.plot(x[row, :], bary[row, :], label='Estimated Barycenter')
slice_ax.set_title('Barycenter Slice at y=0.6')
slice_ax.set_xlabel('x')
slice_ax.set_ylabel('density')
slice_ax.legend()
plt.show()
[7]:
<matplotlib.legend.Legend at 0x16ccd3160>
../_images/examples_Barycenter_7_1.png
[ ]:

[ ]: