{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Barycenter of MNIST Digits\n",
    "\n",
    "Picks several MNIST training images of one digit, rescales them to equal total mass, and computes their barycenter with `MMOTSolver` from the `mmot` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports live in this one cell so the notebook survives Restart & Run All.\n",
    "import gzip\n",
    "import hashlib\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import requests\n",
    "\n",
    "from mmot import MMOTSolver"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and open the MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch data (adapted from https://github.com/geohot/ai-notebooks/blob/master/mnist_from_scratch.ipynb)\n",
    "path = './'  # cache directory for downloaded archives\n",
    "\n",
    "def fetch(url, timeout=60):\n",
    "    \"\"\"Download (or read from the on-disk cache) the gzip file at `url`; return its contents as a uint8 array.\n",
    "\n",
    "    The cache file is named after the MD5 hex digest of the URL. It is written through a\n",
    "    temporary file only after a successful HTTP response, so a failed or interrupted\n",
    "    download cannot leave a corrupt cache entry that would break every later run.\n",
    "\n",
    "    Raises requests.HTTPError on a non-success status and requests.Timeout after `timeout` seconds.\n",
    "    \"\"\"\n",
    "    fp = os.path.join(path, hashlib.md5(url.encode('utf-8')).hexdigest())\n",
    "    if os.path.isfile(fp):\n",
    "        with open(fp, 'rb') as f:\n",
    "            data = f.read()\n",
    "    else:\n",
    "        response = requests.get(url, timeout=timeout)\n",
    "        response.raise_for_status()  # never cache an HTTP error page\n",
    "        data = response.content\n",
    "        tmp_fp = fp + '.part'\n",
    "        with open(tmp_fp, 'wb') as f:\n",
    "            f.write(data)\n",
    "        os.replace(tmp_fp, fp)  # publish the cache entry only once it is complete\n",
    "    return np.frombuffer(gzip.decompress(data), dtype=np.uint8).copy()\n",
    "\n",
    "# Skip the file headers: 16 bytes for the image file, 8 bytes for the label file.\n",
    "digits = fetch(\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\")[0x10:].reshape((-1, 28, 28))\n",
    "labels = fetch(\"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\")[8:]\n",
    "assert digits.shape[0] == labels.shape[0], 'image and label counts differ'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot a few samples of the digit we're interested in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "desired_digit = 3\n",
    "inds = np.where(labels == desired_digit)[0]  # dataset indices of that digit\n",
    "\n",
    "num_plot = 5\n",
    "fig, axs = plt.subplots(ncols=num_plot, sharey=True, figsize=(num_plot*5, 5))\n",
    "for ax, k in zip(axs, inds[:num_plot]):\n",
    "    ax.imshow(digits[k], cmap='Greys')\n",
    "    ax.set_title(f'MNIST index {k}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract digits with similar total measure\n",
    "\n",
    "Candidates are ranked by how close their total pixel intensity is to that of the first sample; each selected image is then rescaled to total mass `n1*n2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of size n1 x n2: cell-centred sample coordinates on the unit square\n",
    "n1 = digits.shape[1]  # x axis\n",
    "n2 = digits.shape[2]  # y axis\n",
    "\n",
    "# Each axis is sampled using its own size at both ends.\n",
    "x, y = np.meshgrid(np.linspace(0.5/n1, 1 - 0.5/n1, n1), np.linspace(0.5/n2, 1 - 0.5/n2, n2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "unroll_node = 0  # passed straight to MMOTSolver below\n",
    "num_digits = 10\n",
    "\n",
    "# Total intensity of every candidate image. np.sum of uint8 data is *unsigned*, so cast to\n",
    "# int64: otherwise `sums - val` wraps around for images lighter than the reference and they\n",
    "# are ranked last instead of by their true distance.\n",
    "sums = digits[inds].sum(axis=(1, 2)).astype(np.int64)\n",
    "val = sums[0]  # reference: the first sample of the desired digit\n",
    "\n",
    "# Stable sort so ties keep dataset order (and the reference itself comes first).\n",
    "sorted_inds = inds[np.argsort(np.abs(sums - val), kind='stable')]\n",
    "# Rescale each selected image to total mass n1*n2 so all measures carry equal mass.\n",
    "measures = [digits[k] * (n1*n2 / digits[k].sum()) for k in sorted_inds[:num_digits]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the edge list for the barycenter problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every pair [i, j] with i < j, in the same order a nested i/j loop would produce.\n",
    "edge_list = [[i, j] for i, j in itertools.combinations(range(num_digits), 2)]\n",
    "\n",
    "weights = np.ones(num_digits) / num_digits  # uniform weights\n",
    "prob = MMOTSolver(measures, edge_list, x, y, unroll_node, weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve the problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration, StepSize, Cost, Error, Line Its\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 2 = 0\n", "next_root 3 = 0\n", "next_root 4 = 0\n", 
"next_root 5 = 0\n", "next_root 6 = 0\n", "next_root 7 = 0\n", "next_root 8 = 0\n", " 0, 0.0008, 1.2998e-04, 1.0349e+00, 8\n", "next_root 0 = 1\n", "next_root 0 = 2\n", "next_root 0 = 3\n", "next_root 0 = 4\n", "next_root 0 = 5\n", "next_root 0 = 6\n", "next_root 0 = 7\n", "next_root 0 = 8\n", "next_root 0 = 9\n", "next_root 1 = 9\n", "next_root 0 = 10\n", "next_root 1 = 10\n", " 10, 0.0012, 1.2499e-03, 1.5150e-01, 1\n", "next_root 0 = 11\n", "next_root 0 = 12\n", "next_root 0 = 13\n", "next_root 0 = 14\n", "next_root 0 = 15\n", "next_root 0 = 16\n", "next_root 0 = 17\n", "next_root 1 = 17\n", "next_root 2 = 17\n", "next_root 0 = 18\n", "next_root 0 = 19\n", "next_root 0 = 20\n", "next_root 1 = 20\n", " 20, 0.0009, 1.4826e-03, 4.6731e-02, 1\n", "next_root 0 = 21\n", "next_root 0 = 22\n", "next_root 0 = 23\n", "next_root 1 = 23\n", "next_root 0 = 24\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 1 = 26\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 0 = 29\n", "next_root 1 = 29\n", "next_root 0 = 30\n", " 30, 0.0007, 1.5566e-03, 4.0157e-02, 0\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 0 = 33\n", "next_root 0 = 34\n", "next_root 0 = 35\n", "next_root 0 = 36\n", "next_root 1 = 36\n", "next_root 0 = 37\n", "next_root 0 = 38\n", "next_root 0 = 39\n", "next_root 1 = 39\n", "next_root 0 = 40\n", " 40, 0.0006, 1.6275e-03, 4.0973e-02, 0\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", "next_root 0 = 43\n", "next_root 0 = 44\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 0 = 1\n", "next_root 0 = 2\n", "next_root 0 = 3\n", "next_root 1 = 3\n", "next_root 0 = 4\n", " 50, 0.0009, 1.7358e-03, 3.2187e-02, 0\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 0 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 0 = 8\n", "next_root 0 = 9\n", "next_root 0 = 10\n", "next_root 0 = 11\n", "next_root 1 = 11\n", "next_root 0 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", 
"next_root 0 = 14\n", " 60, 0.0003, 1.7963e-03, 2.8782e-02, 0\n", "next_root 0 = 15\n", "next_root 0 = 16\n", "next_root 0 = 17\n", "next_root 0 = 18\n", "next_root 0 = 19\n", "next_root 1 = 19\n", "next_root 0 = 20\n", "next_root 0 = 21\n", "next_root 0 = 22\n", "next_root 1 = 22\n", "next_root 0 = 23\n", "next_root 0 = 24\n", " 70, 0.0005, 1.8271e-03, 2.7783e-02, 0\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 1 = 26\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 0 = 29\n", "next_root 1 = 29\n", "next_root 0 = 30\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 0 = 33\n", "next_root 1 = 33\n", "next_root 0 = 34\n", " 80, 0.0002, 1.8805e-03, 2.8257e-02, 0\n", "next_root 0 = 35\n", "next_root 0 = 36\n", "next_root 0 = 37\n", "next_root 0 = 38\n", "next_root 1 = 38\n", "next_root 0 = 39\n", "next_root 0 = 40\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", "next_root 0 = 43\n", "next_root 0 = 44\n", " 90, 0.0003, 1.9226e-03, 2.9543e-02, 0\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 0 = 2\n", "next_root 0 = 3\n", "next_root 0 = 4\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 0 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 0 = 8\n", " 100, 0.0001, 1.9420e-03, 2.8192e-02, 0\n", "next_root 0 = 9\n", "next_root 0 = 10\n", "next_root 0 = 11\n", "next_root 0 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", "next_root 0 = 14\n", "next_root 0 = 15\n", "next_root 0 = 16\n", "next_root 0 = 17\n", "next_root 1 = 17\n", "next_root 0 = 18\n", " 110, 0.0002, 1.9607e-03, 2.6876e-02, 0\n", "next_root 0 = 19\n", "next_root 0 = 20\n", "next_root 1 = 20\n", "next_root 0 = 21\n", "next_root 0 = 22\n", "next_root 0 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 1 = 28\n", " 120, 0.0001, 1.9792e-03, 
2.6153e-02, 1\n", "next_root 0 = 29\n", "next_root 0 = 30\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 0 = 33\n", "next_root 1 = 33\n", "next_root 0 = 34\n", "next_root 0 = 35\n", "next_root 0 = 36\n", "next_root 0 = 37\n", "next_root 0 = 38\n", "next_root 1 = 38\n", " 130, 0.0001, 1.9903e-03, 2.7401e-02, 1\n", "next_root 0 = 39\n", "next_root 0 = 40\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", "next_root 0 = 43\n", "next_root 0 = 44\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 0 = 2\n", "next_root 1 = 2\n", " 140, 0.0000, 2.0004e-03, 2.7720e-02, 1\n", "next_root 0 = 3\n", "next_root 0 = 4\n", "next_root 0 = 5\n", "next_root 0 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 0 = 8\n", "next_root 0 = 9\n", "next_root 0 = 10\n", "next_root 0 = 11\n", "next_root 1 = 11\n", "next_root 0 = 12\n", " 150, 0.0001, 2.0087e-03, 2.7586e-02, 0\n", "next_root 0 = 13\n", "next_root 0 = 14\n", "next_root 0 = 15\n", "next_root 1 = 15\n", "next_root 0 = 16\n", "next_root 0 = 17\n", "next_root 0 = 18\n", "next_root 0 = 19\n", "next_root 0 = 20\n", "next_root 1 = 20\n", "next_root 0 = 21\n", "next_root 0 = 22\n", " 160, 0.0001, 2.0155e-03, 2.7085e-02, 0\n", "next_root 0 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 1 = 28\n", "next_root 0 = 29\n", "next_root 0 = 30\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", " 170, 0.0001, 2.0205e-03, 2.7040e-02, 1\n", "next_root 0 = 33\n", "next_root 1 = 33\n", "next_root 0 = 34\n", "next_root 0 = 35\n", "next_root 0 = 36\n", "next_root 0 = 37\n", "next_root 1 = 37\n", "next_root 0 = 38\n", "next_root 0 = 39\n", "next_root 0 = 40\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", " 180, 0.0001, 2.0259e-03, 2.7703e-02, 1\n", "next_root 0 = 43\n", 
"next_root 0 = 44\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 0 = 2\n", "next_root 1 = 2\n", "next_root 0 = 3\n", "next_root 0 = 4\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 0 = 6\n", " 190, 0.0000, 2.0300e-03, 2.7548e-02, 0\n", "next_root 0 = 7\n", "next_root 0 = 8\n", "next_root 0 = 9\n", "next_root 0 = 10\n", "next_root 0 = 11\n", "next_root 1 = 11\n", "next_root 0 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", "next_root 0 = 14\n", "next_root 0 = 15\n", "next_root 0 = 16\n", " 200, 0.0000, 2.0350e-03, 2.7808e-02, 0\n", "next_root 0 = 17\n", "next_root 0 = 18\n", "next_root 0 = 19\n", "next_root 1 = 19\n", "next_root 0 = 20\n", "next_root 0 = 21\n", "next_root 0 = 22\n", "next_root 0 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 1 = 26\n", " 210, 0.0000, 2.0386e-03, 2.7053e-02, 1\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 0 = 29\n", "next_root 1 = 29\n", "next_root 0 = 30\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 0 = 33\n", "next_root 0 = 34\n", "next_root 0 = 35\n", "next_root 0 = 36\n", " 220, 0.0000, 2.0425e-03, 2.7091e-02, 0\n", "next_root 0 = 37\n", "next_root 1 = 37\n", "next_root 0 = 38\n", "next_root 0 = 39\n", "next_root 0 = 40\n", "next_root 1 = 40\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 0 = 43\n", "next_root 0 = 44\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", " 230, 0.0000, 2.0447e-03, 2.7168e-02, 1\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 2 = 1\n", "next_root 0 = 2\n", "next_root 1 = 2\n", "next_root 0 = 3\n", "next_root 1 = 3\n", "next_root 0 = 4\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 0 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 0 = 8\n", "next_root 0 = 9\n", "next_root 0 = 10\n", " 240, 0.0000, 2.0456e-03, 2.7954e-02, 0\n", "next_root 0 = 11\n", 
"next_root 1 = 11\n", "next_root 0 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", "next_root 0 = 14\n", "next_root 0 = 15\n", "next_root 1 = 15\n", "next_root 0 = 16\n", "next_root 0 = 17\n", "next_root 0 = 18\n", "next_root 0 = 19\n", "next_root 1 = 19\n", "next_root 0 = 20\n", " 250, 0.0000, 2.0472e-03, 2.7679e-02, 0\n", "next_root 0 = 21\n", "next_root 0 = 22\n", "next_root 0 = 23\n", "next_root 1 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 0 = 27\n", "next_root 0 = 28\n", "next_root 1 = 28\n", "next_root 0 = 29\n", "next_root 0 = 30\n", " 260, 0.0000, 2.0484e-03, 2.7547e-02, 0\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 0 = 33\n", "next_root 1 = 33\n", "next_root 0 = 34\n", "next_root 0 = 35\n", "next_root 1 = 35\n", "next_root 0 = 36\n", "next_root 0 = 37\n", "next_root 0 = 38\n", "next_root 1 = 38\n", "next_root 0 = 39\n", "next_root 0 = 40\n", " 270, 0.0000, 2.0494e-03, 2.7256e-02, 0\n", "next_root 0 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", "next_root 0 = 43\n", "next_root 0 = 44\n", "next_root 0 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 2 = 1\n", "next_root 0 = 2\n", "next_root 1 = 2\n", "next_root 2 = 2\n", "next_root 0 = 3\n", "next_root 1 = 3\n", "next_root 0 = 4\n", "next_root 1 = 4\n", " 280, 0.0000, 2.0503e-03, 2.7424e-02, 1\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 2 = 5\n", "next_root 0 = 6\n", "next_root 1 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 2 = 7\n", "next_root 0 = 8\n", "next_root 1 = 8\n", "next_root 0 = 9\n", "next_root 1 = 9\n", "next_root 0 = 10\n", "next_root 1 = 10\n", "next_root 0 = 11\n", "next_root 1 = 11\n", "next_root 0 = 12\n", "next_root 1 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", "next_root 0 = 14\n", "next_root 1 = 14\n", " 290, 0.0000, 2.0511e-03, 2.7508e-02, 1\n", "next_root 0 = 15\n", "next_root 1 = 
15\n", "next_root 0 = 16\n", "next_root 1 = 16\n", "next_root 0 = 17\n", "next_root 1 = 17\n", "next_root 0 = 18\n", "next_root 1 = 18\n", "next_root 0 = 19\n", "next_root 1 = 19\n", "next_root 0 = 20\n", "next_root 1 = 20\n", "next_root 0 = 21\n", "next_root 1 = 21\n", "next_root 0 = 22\n", "next_root 1 = 22\n", "next_root 0 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", " 300, 0.0000, 2.0520e-03, 2.7258e-02, 1\n", "next_root 0 = 25\n", "next_root 0 = 26\n", "next_root 1 = 26\n", "next_root 0 = 27\n", "next_root 1 = 27\n", "next_root 0 = 28\n", "next_root 1 = 28\n", "next_root 0 = 29\n", "next_root 1 = 29\n", "next_root 0 = 30\n", "next_root 1 = 30\n", "next_root 0 = 31\n", "next_root 0 = 32\n", "next_root 1 = 32\n", "next_root 2 = 32\n", "next_root 0 = 33\n", "next_root 1 = 33\n", "next_root 2 = 33\n", "next_root 0 = 34\n", "next_root 1 = 34\n", " 310, 0.0000, 2.0527e-03, 2.7769e-02, 1\n", "next_root 0 = 35\n", "next_root 1 = 35\n", "next_root 0 = 36\n", "next_root 1 = 36\n", "next_root 0 = 37\n", "next_root 1 = 37\n", "next_root 0 = 38\n", "next_root 1 = 38\n", "next_root 0 = 39\n", "next_root 1 = 39\n", "next_root 0 = 40\n", "next_root 1 = 40\n", "next_root 0 = 41\n", "next_root 1 = 41\n", "next_root 0 = 42\n", "next_root 1 = 42\n", "next_root 0 = 43\n", "next_root 1 = 43\n", "next_root 0 = 44\n", "next_root 1 = 44\n", " 320, 0.0000, 2.0535e-03, 2.7419e-02, 1\n", "next_root 0 = 45\n", "next_root 1 = 45\n", "next_root 0 = 0\n", "next_root 1 = 0\n", "next_root 2 = 0\n", "next_root 0 = 1\n", "next_root 1 = 1\n", "next_root 2 = 1\n", "next_root 0 = 2\n", "next_root 1 = 2\n", "next_root 2 = 2\n", "next_root 3 = 2\n", "next_root 0 = 3\n", "next_root 1 = 3\n", "next_root 2 = 3\n", "next_root 0 = 4\n", "next_root 1 = 4\n", "next_root 2 = 4\n", "next_root 0 = 5\n", "next_root 1 = 5\n", "next_root 2 = 5\n", "next_root 3 = 5\n", "next_root 0 = 6\n", "next_root 1 = 6\n", "next_root 2 = 6\n", "next_root 0 = 7\n", "next_root 1 = 7\n", "next_root 2 = 7\n", "next_root 3 
= 7\n", "next_root 0 = 8\n", "next_root 1 = 8\n", "next_root 2 = 8\n", " 330, 0.0000, 2.0539e-03, 2.7790e-02, 2\n", "next_root 0 = 9\n", "next_root 1 = 9\n", "next_root 2 = 9\n", "next_root 0 = 10\n", "next_root 1 = 10\n", "next_root 2 = 10\n", "next_root 0 = 11\n", "next_root 1 = 11\n", "next_root 2 = 11\n", "next_root 0 = 12\n", "next_root 1 = 12\n", "next_root 2 = 12\n", "next_root 0 = 13\n", "next_root 1 = 13\n", "next_root 2 = 13\n", "next_root 0 = 14\n", "next_root 1 = 14\n", "next_root 2 = 14\n", "next_root 0 = 15\n", "next_root 1 = 15\n", "next_root 2 = 15\n", "next_root 0 = 16\n", "next_root 1 = 16\n", "next_root 2 = 16\n", "next_root 0 = 17\n", "next_root 1 = 17\n", "next_root 2 = 17\n", "next_root 0 = 18\n", "next_root 1 = 18\n", "next_root 2 = 18\n", " 340, 0.0000, 2.0544e-03, 2.7167e-02, 2\n", "next_root 0 = 19\n", "next_root 1 = 19\n", "next_root 2 = 19\n", "next_root 0 = 20\n", "next_root 1 = 20\n", "next_root 2 = 20\n", "next_root 0 = 21\n", "next_root 1 = 21\n", "next_root 2 = 21\n", "next_root 0 = 22\n", "next_root 1 = 22\n", "next_root 2 = 22\n", "next_root 0 = 23\n", "next_root 1 = 23\n", "next_root 2 = 23\n", "next_root 0 = 24\n", "next_root 1 = 24\n", "next_root 2 = 24\n", "next_root 0 = 25\n", "next_root 1 = 25\n", " 347, 0.0000, 2.0548e-03, 2.7386e-02, 1\n", "Terminating due to small change in objective.\n" ] } ],
   "source": [
    "# Keyword arguments for Solve. The notes below are inferred from the names and the log (TODO confirm against mmot docs).\n",
    "solve_options = {\n",
    "    'max_its': 10000,   # iteration cap\n",
    "    'step_size': 0.2,   # presumably the initial step before line search\n",
    "    'ftol_abs': 1e-9,   # presumably the absolute objective-change tolerance (cf. the termination message)\n",
    "    'gtol_abs': 1e-3,   # presumably the absolute gradient tolerance\n",
    "}\n",
    "res = prob.Solve(**solve_options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "# Convergence diagnostics recorded by the solver: (series, title, log-scale y axis?)\n",
    "panels = [\n",
    "    (res.costs, 'Dual Objective', False),\n",
    "    (res.grad_sq_norms, 'Squared Norm of Gradient', True),\n",
    "    (res.step_sizes, 'Step Size (after line search)', True),\n",
    "    (res.line_its, 'Line Search Iterations', False),\n",
    "]\n",
    "\n",
    "fig, axs = plt.subplots(ncols=len(panels), sharex=True, figsize=(16, 4))\n",
    "for ax, (series, title, log_y) in zip(axs, panels):\n",
    "    draw = ax.semilogy if log_y else ax.plot\n",
    "    draw(series)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('Iteration')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute the Barycenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recover the barycenter from the dual variables returned by Solve.\n",
    "bary = prob.Barycenter(res.dual_vars)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the images used and the barycenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Barycenter')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "# One shared colour scale for every panel, including the barycenter, so no panel is clipped.\n",
    "vmax = max(np.max(m) for m in [*measures, bary])\n",
    "fig, axs = plt.subplots(1, num_digits+1, figsize=((num_digits+1)*4, 4))\n",
    "\n",
    "# Draw array row 0 at the top, as in the sample plot above (imshow's default), so the digits\n",
    "# appear the same way up; extent (0, 1, 1, 0) keeps the unit-square y axis running 0 -> 1\n",
    "# down the rows.\n",
    "show = dict(origin='upper', extent=(0, 1, 1, 0), vmin=0, vmax=vmax, cmap='Greys')\n",
    "for i in range(num_digits):\n",
    "    axs[i].imshow(measures[i], **show)\n",
    "    axs[i].set_title(\"$\\\\mu_{{ {:0d} }}$\".format(i))\n",
    "\n",
    "axs[-1].imshow(bary, **show)\n",
    "axs[-1].set_title('Barycenter')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "545e036c4b32438aced1f6b3c8d38ca151d9c36189e05839cb0aa568fda70ddd"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}