Barycenter of MNIST Digits#

[1]:

import itertools

import matplotlib.pyplot as plt
import numpy as np

from mmot import MMOTSolver

Download and open the MNIST dataset#

[2]:
import hashlib
import os
import requests
import gzip

#fetch data (adapted from https://github.com/geohot/ai-notebooks/blob/master/mnist_from_scratch.ipynb)
path='./'
def fetch(url, timeout=60):
    """Download a gzip file (or reuse a cached copy) and return its contents.

    The compressed response body is cached in the directory ``path`` under a
    file named after the MD5 hex digest of ``url``; later calls with the same
    URL read that file instead of hitting the network.

    Parameters
    ----------
    url : str
        Address of the gzip-compressed resource.
    timeout : float, optional
        Seconds passed to ``requests.get`` as its ``timeout`` argument.

    Returns
    -------
    numpy.ndarray
        One-dimensional, writable ``uint8`` array holding the decompressed bytes.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    OSError
        If the payload is not valid gzip data.
    """
    fp = os.path.join(path, hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp):
        with open(fp, "rb") as f:
            data = f.read()
        raw = gzip.decompress(data)
    else:
        # Download and validate *before* touching the cache, so that a failed
        # request or a non-gzip body never leaves a poisoned cache entry behind.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        data = response.content
        raw = gzip.decompress(data)

        # Write to a temporary name and rename, so an interrupted write cannot
        # be mistaken for a complete cache file on the next run.
        tmp_fp = fp + ".part"
        with open(tmp_fp, "wb") as f:
            f.write(data)
        os.replace(tmp_fp, fp)
    # frombuffer gives a read-only view of the bytes; copy() makes it writable.
    return np.frombuffer(raw, dtype=np.uint8).copy()

# Images: drop the leading 16 bytes of the decompressed file, then view the
# remaining bytes as a stack of 28x28 uint8 images (one image per first-axis entry).
digits = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[16:].reshape(-1, 28, 28)
# Labels: drop the leading 8 bytes; every remaining byte is one label.
labels = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]

Plot a few samples of the digit we’re interested in#

[3]:
# Which digit class to work with, and the positions of all its samples.
desired_digit = 3
inds, = np.where(labels == desired_digit)

# Show the first few samples of that class side by side.
num_plot = 5
fig, axs = plt.subplots(ncols=num_plot, sharey=True, figsize=(num_plot*5, 5))
for i, ax in enumerate(axs):
    ax.imshow(digits[inds[i]], cmap='Greys')

../_images/examples_MNISTBarycenter_5_0.png

Extract digits with similar total measure#

[4]:
# Grid of size n1 x n2
n1 = digits.shape[1]   # x axis
n2 = digits.shape[2]   # y axis

# Coordinates of the cell centres of a uniform n1 x n2 partition of the unit
# square: along each axis the centres run from 0.5/n to 1 - 0.5/n.
# (The y-axis upper bound previously used n1; it must use n2 so that the grid
# is also correct when n1 != n2.)
x, y = np.meshgrid(np.linspace(0.5/n1, 1-0.5/n1, n1),
                   np.linspace(0.5/n2, 1-0.5/n2, n2))
[5]:
# Passed to MMOTSolver below as its `unroll_node` argument.
# NOTE(review): presumably the index of the node the solver treats as the root — confirm in the mmot docs.
unroll_node = 0

# Number of images that enter the barycenter problem.
num_digits = 10

# Total pixel mass of every candidate image.
# The pixels are uint8 and NumPy sums those with an *unsigned* accumulator by
# default, so `sums - val` would wrap around to a huge positive number for any
# image lighter than the reference, and np.abs could not undo that. Summing in
# int64 keeps the differences signed and the ranking correct.
sums = digits[inds].sum(axis=(1, 2), dtype=np.int64)

# Reference mass: the first image of the chosen digit class.
val = sums[0]

# Candidates ordered by how close their total mass is to the reference
# (stable sort, so ties keep their original dataset order).
sorted_inds = inds[np.argsort(np.abs(sums - val), kind='stable')]

# Rescale each selected image so that its pixels sum to n1*n2, i.e. every
# measure has mean pixel value 1 on the grid.
measures = [digits[k] * (n1*n2 / digits[k].sum(dtype=np.float64))
            for k in sorted_inds[:num_digits]]

Define the edge list for the barycenter problem#

[6]:

# One edge [i, j] for every unordered pair of measures with i < j, listed in
# lexicographic order.
edge_list = [list(pair) for pair in itertools.combinations(range(num_digits), 2)]

# Equal weight for every measure; the weights sum to one.
weights = np.ones(num_digits) / num_digits

prob = MMOTSolver(measures, edge_list, x, y, unroll_node, weights)

Solve the problem#

[7]:
# Run the solver. The result object is used below for its per-iteration
# histories (costs, grad_sq_norms, step_sizes, line_its) and for dual_vars.
# NOTE(review): presumably max_its caps the iteration count, step_size is the
# initial step, and ftol_abs / gtol_abs are absolute tolerances on the change
# in objective and on the gradient — confirm against the mmot documentation.
res = prob.Solve(max_its=10000, step_size=0.2, ftol_abs=1e-9, gtol_abs=1e-3)

Iteration, StepSize,        Cost,        Error,  Line Its
next_root 0 =  0
next_root 1 =  0
next_root 2 =  0
next_root 3 =  0
next_root 4 =  0
next_root 5 =  0
next_root 6 =  0
next_root 7 =  0
next_root 8 =  0
        0,   0.0008,  1.2998e-04,   1.0349e+00,         8
next_root 0 =  1
next_root 0 =  2
next_root 0 =  3
next_root 0 =  4
next_root 0 =  5
next_root 0 =  6
next_root 0 =  7
next_root 0 =  8
next_root 0 =  9
next_root 1 =  9
next_root 0 =  10
next_root 1 =  10
       10,   0.0012,  1.2499e-03,   1.5150e-01,         1
next_root 0 =  11
next_root 0 =  12
next_root 0 =  13
next_root 0 =  14
next_root 0 =  15
next_root 0 =  16
next_root 0 =  17
next_root 1 =  17
next_root 2 =  17
next_root 0 =  18
next_root 0 =  19
next_root 0 =  20
next_root 1 =  20
       20,   0.0009,  1.4826e-03,   4.6731e-02,         1
next_root 0 =  21
next_root 0 =  22
next_root 0 =  23
next_root 1 =  23
next_root 0 =  24
next_root 0 =  25
next_root 0 =  26
next_root 1 =  26
next_root 0 =  27
next_root 0 =  28
next_root 0 =  29
next_root 1 =  29
next_root 0 =  30
       30,   0.0007,  1.5566e-03,   4.0157e-02,         0
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 0 =  33
next_root 0 =  34
next_root 0 =  35
next_root 0 =  36
next_root 1 =  36
next_root 0 =  37
next_root 0 =  38
next_root 0 =  39
next_root 1 =  39
next_root 0 =  40
       40,   0.0006,  1.6275e-03,   4.0973e-02,         0
next_root 0 =  41
next_root 0 =  42
next_root 1 =  42
next_root 0 =  43
next_root 0 =  44
next_root 0 =  45
next_root 0 =  0
next_root 0 =  1
next_root 0 =  2
next_root 0 =  3
next_root 1 =  3
next_root 0 =  4
       50,   0.0009,  1.7358e-03,   3.2187e-02,         0
next_root 0 =  5
next_root 1 =  5
next_root 0 =  6
next_root 0 =  7
next_root 1 =  7
next_root 0 =  8
next_root 0 =  9
next_root 0 =  10
next_root 0 =  11
next_root 1 =  11
next_root 0 =  12
next_root 0 =  13
next_root 1 =  13
next_root 0 =  14
       60,   0.0003,  1.7963e-03,   2.8782e-02,         0
next_root 0 =  15
next_root 0 =  16
next_root 0 =  17
next_root 0 =  18
next_root 0 =  19
next_root 1 =  19
next_root 0 =  20
next_root 0 =  21
next_root 0 =  22
next_root 1 =  22
next_root 0 =  23
next_root 0 =  24
       70,   0.0005,  1.8271e-03,   2.7783e-02,         0
next_root 0 =  25
next_root 0 =  26
next_root 1 =  26
next_root 0 =  27
next_root 0 =  28
next_root 0 =  29
next_root 1 =  29
next_root 0 =  30
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 0 =  33
next_root 1 =  33
next_root 0 =  34
       80,   0.0002,  1.8805e-03,   2.8257e-02,         0
next_root 0 =  35
next_root 0 =  36
next_root 0 =  37
next_root 0 =  38
next_root 1 =  38
next_root 0 =  39
next_root 0 =  40
next_root 0 =  41
next_root 0 =  42
next_root 1 =  42
next_root 0 =  43
next_root 0 =  44
       90,   0.0003,  1.9226e-03,   2.9543e-02,         0
next_root 0 =  45
next_root 0 =  0
next_root 1 =  0
next_root 0 =  1
next_root 1 =  1
next_root 0 =  2
next_root 0 =  3
next_root 0 =  4
next_root 0 =  5
next_root 1 =  5
next_root 0 =  6
next_root 0 =  7
next_root 1 =  7
next_root 0 =  8
      100,   0.0001,  1.9420e-03,   2.8192e-02,         0
next_root 0 =  9
next_root 0 =  10
next_root 0 =  11
next_root 0 =  12
next_root 0 =  13
next_root 1 =  13
next_root 0 =  14
next_root 0 =  15
next_root 0 =  16
next_root 0 =  17
next_root 1 =  17
next_root 0 =  18
      110,   0.0002,  1.9607e-03,   2.6876e-02,         0
next_root 0 =  19
next_root 0 =  20
next_root 1 =  20
next_root 0 =  21
next_root 0 =  22
next_root 0 =  23
next_root 0 =  24
next_root 1 =  24
next_root 0 =  25
next_root 0 =  26
next_root 0 =  27
next_root 0 =  28
next_root 1 =  28
      120,   0.0001,  1.9792e-03,   2.6153e-02,         1
next_root 0 =  29
next_root 0 =  30
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 0 =  33
next_root 1 =  33
next_root 0 =  34
next_root 0 =  35
next_root 0 =  36
next_root 0 =  37
next_root 0 =  38
next_root 1 =  38
      130,   0.0001,  1.9903e-03,   2.7401e-02,         1
next_root 0 =  39
next_root 0 =  40
next_root 0 =  41
next_root 0 =  42
next_root 1 =  42
next_root 0 =  43
next_root 0 =  44
next_root 0 =  45
next_root 0 =  0
next_root 1 =  0
next_root 0 =  1
next_root 1 =  1
next_root 0 =  2
next_root 1 =  2
      140,   0.0000,  2.0004e-03,   2.7720e-02,         1
next_root 0 =  3
next_root 0 =  4
next_root 0 =  5
next_root 0 =  6
next_root 0 =  7
next_root 1 =  7
next_root 0 =  8
next_root 0 =  9
next_root 0 =  10
next_root 0 =  11
next_root 1 =  11
next_root 0 =  12
      150,   0.0001,  2.0087e-03,   2.7586e-02,         0
next_root 0 =  13
next_root 0 =  14
next_root 0 =  15
next_root 1 =  15
next_root 0 =  16
next_root 0 =  17
next_root 0 =  18
next_root 0 =  19
next_root 0 =  20
next_root 1 =  20
next_root 0 =  21
next_root 0 =  22
      160,   0.0001,  2.0155e-03,   2.7085e-02,         0
next_root 0 =  23
next_root 0 =  24
next_root 1 =  24
next_root 0 =  25
next_root 0 =  26
next_root 0 =  27
next_root 0 =  28
next_root 1 =  28
next_root 0 =  29
next_root 0 =  30
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
      170,   0.0001,  2.0205e-03,   2.7040e-02,         1
next_root 0 =  33
next_root 1 =  33
next_root 0 =  34
next_root 0 =  35
next_root 0 =  36
next_root 0 =  37
next_root 1 =  37
next_root 0 =  38
next_root 0 =  39
next_root 0 =  40
next_root 0 =  41
next_root 0 =  42
next_root 1 =  42
      180,   0.0001,  2.0259e-03,   2.7703e-02,         1
next_root 0 =  43
next_root 0 =  44
next_root 0 =  45
next_root 0 =  0
next_root 1 =  0
next_root 0 =  1
next_root 1 =  1
next_root 0 =  2
next_root 1 =  2
next_root 0 =  3
next_root 0 =  4
next_root 0 =  5
next_root 1 =  5
next_root 0 =  6
      190,   0.0000,  2.0300e-03,   2.7548e-02,         0
next_root 0 =  7
next_root 0 =  8
next_root 0 =  9
next_root 0 =  10
next_root 0 =  11
next_root 1 =  11
next_root 0 =  12
next_root 0 =  13
next_root 1 =  13
next_root 0 =  14
next_root 0 =  15
next_root 0 =  16
      200,   0.0000,  2.0350e-03,   2.7808e-02,         0
next_root 0 =  17
next_root 0 =  18
next_root 0 =  19
next_root 1 =  19
next_root 0 =  20
next_root 0 =  21
next_root 0 =  22
next_root 0 =  23
next_root 0 =  24
next_root 1 =  24
next_root 0 =  25
next_root 0 =  26
next_root 1 =  26
      210,   0.0000,  2.0386e-03,   2.7053e-02,         1
next_root 0 =  27
next_root 0 =  28
next_root 0 =  29
next_root 1 =  29
next_root 0 =  30
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 0 =  33
next_root 0 =  34
next_root 0 =  35
next_root 0 =  36
      220,   0.0000,  2.0425e-03,   2.7091e-02,         0
next_root 0 =  37
next_root 1 =  37
next_root 0 =  38
next_root 0 =  39
next_root 0 =  40
next_root 1 =  40
next_root 0 =  41
next_root 0 =  42
next_root 0 =  43
next_root 0 =  44
next_root 0 =  45
next_root 0 =  0
next_root 1 =  0
      230,   0.0000,  2.0447e-03,   2.7168e-02,         1
next_root 0 =  1
next_root 1 =  1
next_root 2 =  1
next_root 0 =  2
next_root 1 =  2
next_root 0 =  3
next_root 1 =  3
next_root 0 =  4
next_root 0 =  5
next_root 1 =  5
next_root 0 =  6
next_root 0 =  7
next_root 1 =  7
next_root 0 =  8
next_root 0 =  9
next_root 0 =  10
      240,   0.0000,  2.0456e-03,   2.7954e-02,         0
next_root 0 =  11
next_root 1 =  11
next_root 0 =  12
next_root 0 =  13
next_root 1 =  13
next_root 0 =  14
next_root 0 =  15
next_root 1 =  15
next_root 0 =  16
next_root 0 =  17
next_root 0 =  18
next_root 0 =  19
next_root 1 =  19
next_root 0 =  20
      250,   0.0000,  2.0472e-03,   2.7679e-02,         0
next_root 0 =  21
next_root 0 =  22
next_root 0 =  23
next_root 1 =  23
next_root 0 =  24
next_root 1 =  24
next_root 0 =  25
next_root 0 =  26
next_root 0 =  27
next_root 0 =  28
next_root 1 =  28
next_root 0 =  29
next_root 0 =  30
      260,   0.0000,  2.0484e-03,   2.7547e-02,         0
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 0 =  33
next_root 1 =  33
next_root 0 =  34
next_root 0 =  35
next_root 1 =  35
next_root 0 =  36
next_root 0 =  37
next_root 0 =  38
next_root 1 =  38
next_root 0 =  39
next_root 0 =  40
      270,   0.0000,  2.0494e-03,   2.7256e-02,         0
next_root 0 =  41
next_root 0 =  42
next_root 1 =  42
next_root 0 =  43
next_root 0 =  44
next_root 0 =  45
next_root 0 =  0
next_root 1 =  0
next_root 0 =  1
next_root 1 =  1
next_root 2 =  1
next_root 0 =  2
next_root 1 =  2
next_root 2 =  2
next_root 0 =  3
next_root 1 =  3
next_root 0 =  4
next_root 1 =  4
      280,   0.0000,  2.0503e-03,   2.7424e-02,         1
next_root 0 =  5
next_root 1 =  5
next_root 2 =  5
next_root 0 =  6
next_root 1 =  6
next_root 0 =  7
next_root 1 =  7
next_root 2 =  7
next_root 0 =  8
next_root 1 =  8
next_root 0 =  9
next_root 1 =  9
next_root 0 =  10
next_root 1 =  10
next_root 0 =  11
next_root 1 =  11
next_root 0 =  12
next_root 1 =  12
next_root 0 =  13
next_root 1 =  13
next_root 0 =  14
next_root 1 =  14
      290,   0.0000,  2.0511e-03,   2.7508e-02,         1
next_root 0 =  15
next_root 1 =  15
next_root 0 =  16
next_root 1 =  16
next_root 0 =  17
next_root 1 =  17
next_root 0 =  18
next_root 1 =  18
next_root 0 =  19
next_root 1 =  19
next_root 0 =  20
next_root 1 =  20
next_root 0 =  21
next_root 1 =  21
next_root 0 =  22
next_root 1 =  22
next_root 0 =  23
next_root 0 =  24
next_root 1 =  24
      300,   0.0000,  2.0520e-03,   2.7258e-02,         1
next_root 0 =  25
next_root 0 =  26
next_root 1 =  26
next_root 0 =  27
next_root 1 =  27
next_root 0 =  28
next_root 1 =  28
next_root 0 =  29
next_root 1 =  29
next_root 0 =  30
next_root 1 =  30
next_root 0 =  31
next_root 0 =  32
next_root 1 =  32
next_root 2 =  32
next_root 0 =  33
next_root 1 =  33
next_root 2 =  33
next_root 0 =  34
next_root 1 =  34
      310,   0.0000,  2.0527e-03,   2.7769e-02,         1
next_root 0 =  35
next_root 1 =  35
next_root 0 =  36
next_root 1 =  36
next_root 0 =  37
next_root 1 =  37
next_root 0 =  38
next_root 1 =  38
next_root 0 =  39
next_root 1 =  39
next_root 0 =  40
next_root 1 =  40
next_root 0 =  41
next_root 1 =  41
next_root 0 =  42
next_root 1 =  42
next_root 0 =  43
next_root 1 =  43
next_root 0 =  44
next_root 1 =  44
      320,   0.0000,  2.0535e-03,   2.7419e-02,         1
next_root 0 =  45
next_root 1 =  45
next_root 0 =  0
next_root 1 =  0
next_root 2 =  0
next_root 0 =  1
next_root 1 =  1
next_root 2 =  1
next_root 0 =  2
next_root 1 =  2
next_root 2 =  2
next_root 3 =  2
next_root 0 =  3
next_root 1 =  3
next_root 2 =  3
next_root 0 =  4
next_root 1 =  4
next_root 2 =  4
next_root 0 =  5
next_root 1 =  5
next_root 2 =  5
next_root 3 =  5
next_root 0 =  6
next_root 1 =  6
next_root 2 =  6
next_root 0 =  7
next_root 1 =  7
next_root 2 =  7
next_root 3 =  7
next_root 0 =  8
next_root 1 =  8
next_root 2 =  8
      330,   0.0000,  2.0539e-03,   2.7790e-02,         2
next_root 0 =  9
next_root 1 =  9
next_root 2 =  9
next_root 0 =  10
next_root 1 =  10
next_root 2 =  10
next_root 0 =  11
next_root 1 =  11
next_root 2 =  11
next_root 0 =  12
next_root 1 =  12
next_root 2 =  12
next_root 0 =  13
next_root 1 =  13
next_root 2 =  13
next_root 0 =  14
next_root 1 =  14
next_root 2 =  14
next_root 0 =  15
next_root 1 =  15
next_root 2 =  15
next_root 0 =  16
next_root 1 =  16
next_root 2 =  16
next_root 0 =  17
next_root 1 =  17
next_root 2 =  17
next_root 0 =  18
next_root 1 =  18
next_root 2 =  18
      340,   0.0000,  2.0544e-03,   2.7167e-02,         2
next_root 0 =  19
next_root 1 =  19
next_root 2 =  19
next_root 0 =  20
next_root 1 =  20
next_root 2 =  20
next_root 0 =  21
next_root 1 =  21
next_root 2 =  21
next_root 0 =  22
next_root 1 =  22
next_root 2 =  22
next_root 0 =  23
next_root 1 =  23
next_root 2 =  23
next_root 0 =  24
next_root 1 =  24
next_root 2 =  24
next_root 0 =  25
next_root 1 =  25
      347,   0.0000,  2.0548e-03,   2.7386e-02,         1
Terminating due to small change in objective.
[8]:

# Convergence diagnostics: one panel per recorded history, all sharing the
# iteration axis.
fig, axs = plt.subplots(ncols=4, sharex=True, figsize=(16,4))

# (Axes method used to draw, data series, panel title)
diagnostics = [
    ('plot',     res.costs,         'Dual Objective'),
    ('semilogy', res.grad_sq_norms, 'Squared Norm of Gradient'),
    ('semilogy', res.step_sizes,    'Step Size (after line search)'),
    ('plot',     res.line_its,      'Line Search Iterations'),
]

for ax, (draw, series, title) in zip(axs, diagnostics):
    getattr(ax, draw)(series)
    ax.set_title(title)
    ax.set_xlabel('Iteration')

plt.show()
../_images/examples_MNISTBarycenter_13_0.png

Compute the Barycenter#

[9]:

# Build the barycenter from the dual variables returned by the solve above.
# The result is displayed as an image in the next cell.
bary = prob.Barycenter(res.dual_vars)

Plot the images used and the barycenter#

[10]:
# Use one colour scale for every panel: 0 up to the brightest pixel found in
# any of the input measures.
vmax = np.max([m.max() for m in measures])
shared_style = dict(origin='lower', extent=(0,1,0,1), vmin=0, vmax=vmax, cmap='Greys')

# One panel per input measure, plus a final panel for the barycenter.
fig, axs = plt.subplots(1, num_digits+1, figsize=((num_digits+1)*4,4))
for i in range(num_digits):
    panel = axs[i]
    panel.imshow(measures[i], **shared_style)
    panel.set_title("$\\mu_{{ {:0d} }}$".format(i))

bary_panel = axs[-1]
bary_panel.imshow(bary, **shared_style)
bary_panel.set_title('Barycenter')
[10]:
Text(0.5, 1.0, 'Barycenter')
../_images/examples_MNISTBarycenter_17_1.png
[ ]:

[ ]: