<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>Kevin Zakka's Blog</title>
    <description></description>
    <link>http://kevinzakka.github.io/</link>
    <atom:link href="http://kevinzakka.github.io/feed.xml" rel="self" type="application/rss+xml" />
    <pubDate>Sun, 05 Jul 2020 23:39:23 +0000</pubDate>
    <lastBuildDate>Sun, 05 Jul 2020 23:39:23 +0000</lastBuildDate>
    <generator>Jekyll v3.8.7</generator>
    
      <item>
        <title>kNN classification using Neighbourhood Components Analysis</title>
        <description>&lt;p&gt;&lt;small&gt;&lt;strong&gt;Update (12/02/2020)&lt;/strong&gt;: The implementation is now available as a &lt;a href=&quot;https://pypi.org/project/torchnca/&quot;&gt;pip package&lt;/a&gt;. Simply run &lt;em&gt;pip install torchnca&lt;/em&gt;.&lt;small&gt;&lt;/small&gt;&lt;/small&gt;&lt;/p&gt;

&lt;p&gt;While reading related work&lt;sup id=&quot;fnref:1&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt; for my current research project, I stumbled upon a reference to a classic paper from 2004 called &lt;em&gt;Neighbourhood Components Analysis&lt;/em&gt; (NCA). After giving it a read, I was instantly charmed by its simplicity and elegance. Long story short, NCA allows you to learn a linear transformation of your data that maximizes k-nearest neighbours performance. By forcing the transformation to be low-rank, NCA will perform dimensionality reduction, leading to vastly reduced storage sizes and search times for kNN! NCA is a very useful algorithm to have in your toolkit – just like PCA – but it’s very rarely mentioned in the wild. In fact, I couldn’t find any tutorial or reference outside of academic papers. This post is an attempt to rectify this.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
    &lt;button id=&quot;animButton&quot; onclick=&quot;toggleAnim()&quot; class=&quot;playbutton&quot;&gt;Play&lt;/button&gt;
    &lt;img alt=&quot;&quot; src=&quot;/assets/nca/banner-start.png&quot; width=&quot;70%&quot; id=&quot;animImage&quot; style=&quot;border:none;&quot; /&gt;
    &lt;div class=&quot;thecap&quot; style=&quot;text-align:center;&quot;&gt;&lt;b&gt;Figure 1:&lt;/b&gt; Visualizing the embedding space of a synthetic dataset as NCA trains.&lt;/div&gt;
&lt;/div&gt;

&lt;script language=&quot;javascript&quot;&gt;
    function toggleAnim() {

        path = document.getElementById(&quot;animImage&quot;).src;
        if (path.split('/').pop() == &quot;banner-start.png&quot;)
        {
            document.getElementById(&quot;animImage&quot;).src = &quot;/assets/nca/banner-smaller.gif&quot;;
            document.getElementById(&quot;animButton&quot;).textContent = &quot;Reset&quot;;
        }
        else
        {
            document.getElementById(&quot;animImage&quot;).src = &quot;/assets/nca/banner-start.png&quot;;
            document.getElementById(&quot;animButton&quot;).textContent = &quot;Play&quot;;
        }
    }
&lt;/script&gt;

&lt;p&gt;I’ve implemented NCA in PyTorch with some added bells and whistles. It took almost 1 week to get it to work right, but I gained a lot of insight along the way. I think implementing algorithms from scratch is a great way of building intuition&lt;sup id=&quot;fnref:2&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt; for why things work – and by extension when and why they don’t – so I encourage the reader to do the same. There’s also a video presentation of NCA by one of the co-authors on &lt;a href=&quot;https://youtu.be/07erva41ZoI&quot;&gt;YouTube&lt;/a&gt; which should serve as a good supplement to this post.&lt;/p&gt;

&lt;div style=&quot;text-align: center;&quot;&gt;
    &lt;a href=&quot;https://papers.nips.cc/paper/2566-neighbourhood-components-analysis.pdf&quot; id=&quot;linkbutton&quot; target=&quot;_blank&quot; style=&quot;margin-right: 10px;&quot;&gt;Paper&lt;/a&gt;
    &lt;a href=&quot;https://github.com/kevinzakka/torchnca&quot; id=&quot;linkbutton&quot; target=&quot;_blank&quot; style=&quot;margin-left: 10px;&quot;&gt;PyTorch Code&lt;/a&gt;
&lt;/div&gt;

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#knn-issues&quot;&gt;kNN: The Good, The Bad, The Ugly&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#nca-rescue&quot;&gt;NCA to the rescue&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#loss-func&quot;&gt;Formulating the loss function&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#contrastive&quot;&gt;NCA as a special case of the contrastive loss&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#pytorch&quot;&gt;NCA in PyTorch&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#init&quot;&gt;Initialization&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#comp-loss&quot;&gt;Loss function&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#sgd&quot;&gt;Replacing Conjugate Gradients with SGD&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#tricks&quot;&gt;Stability tricks&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#results&quot;&gt;Boring… Show me what it can do!&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#dim-reduct&quot;&gt;Dimensionality reduction&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#sentiment&quot;&gt;kNN on MNIST&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#thankyou&quot;&gt;Acknowledgements&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;knn-issues&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;knn-the-good-the-bad-the-ugly&quot;&gt;kNN: The Good, The Bad, The Ugly&lt;/h2&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/knn/teaser.png&quot; width=&quot;60%&quot; style=&quot;border:none;&quot; /&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;b&gt;Figure 2:&lt;/b&gt; kNN's nonlinear decision boundary &lt;a href=&quot;http://scott.fortmann-roe.com/docs/BiasVariance.html&quot;&gt;(source)&lt;/a&gt;.&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;You’ve probably &lt;a href=&quot;https://kevinzakka.github.io/2016/07/13/k-nearest-neighbor/&quot;&gt;heard&lt;/a&gt; of k-nearest neighbours (kNN) &lt;em&gt;at least&lt;/em&gt; once in your life. It’s one of the first algorithms taught in many machine learning classes, and not without good reason. There’s lots to love about kNN! To name a few:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;It has an extremely simple implementation. In fact, kNN has absolutely no computational training cost.&lt;/li&gt;
  &lt;li&gt;Its decision boundary, controlled by &lt;script type=&quot;math/tex&quot;&gt;k&lt;/script&gt;, is highly nonlinear (the lines in Figure 2 are locally linear but their overall shape can’t be defined by a hyperplane). For low values of &lt;script type=&quot;math/tex&quot;&gt;k&lt;/script&gt;, kNN has very little &lt;a href=&quot;https://en.wikipedia.org/wiki/Inductive_bias&quot;&gt;inductive bias&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;There’s just a single hyperparameter to tune: the number of neighbours &lt;script type=&quot;math/tex&quot;&gt;k&lt;/script&gt;. You can easily find its optimal value with cross-validation.&lt;/li&gt;
  &lt;li&gt;It comes with a strong asymptotic guarantee. One can show that as the amount of data approaches infinity, the 1-nearest-neighbour classifier is guaranteed to yield an error rate no worse than twice the Bayes error rate – the lowest possible error rate for any classifier – on a binary classification task. Or in other words, you can expect the performance of kNN to automatically improve as the number of training examples increases.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;But kNN does have some annoying drawbacks that limit its efficiency in big-data regimes. Specifically,&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;It has to store and search through the entire training data to classify just one test point. Without any optimizations, test-time classification is roughly &lt;script type=&quot;math/tex&quot;&gt;\mathcal{O}(n)&lt;/script&gt; given &lt;script type=&quot;math/tex&quot;&gt;n \gg d&lt;/script&gt;. That’s extremely unappealing from a deployment perspective since we usually aim for high test-time efficiency and a low memory footprint.&lt;/li&gt;
  &lt;li&gt;In high dimensions, it suffers from the &lt;a href=&quot;https://en.wikipedia.org/wiki/Curse_of_dimensionality&quot;&gt;curse of dimensionality&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;The choice of the distance metric can have a significant effect on its performance. What then is the optimal distance metric? How should one go about choosing it?&lt;/li&gt;
&lt;/ul&gt;
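
&lt;p&gt;To make the storage and search cost concrete, here is a minimal brute-force kNN sketch in PyTorch. The function name &lt;code&gt;knn_predict&lt;/code&gt; and the toy data are made up for illustration; the point is only that every prediction touches the full training set.&lt;/p&gt;

```python
import torch


def knn_predict(train_x, train_y, test_x, k=3):
    """Brute-force kNN: compare every test point with every training point."""
    # (num_test, num_train) matrix of Euclidean distances
    dists = torch.cdist(test_x, train_x)
    # indices of the k closest training points for each test point
    idx = dists.topk(k, dim=1, largest=False).indices
    # majority vote over the labels of those neighbours
    votes = train_y[idx]
    return torch.mode(votes, dim=1).values


# toy data: two well-separated clusters in 2D (hypothetical values)
train_x = torch.tensor([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                        [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
train_y = torch.tensor([0, 0, 0, 1, 1, 1])
test_x = torch.tensor([[0.05, 0.05], [5.0, 5.1]])
pred = knn_predict(train_x, train_y, test_x, k=3)
```

&lt;p&gt;The distance matrix has one column per training point, which is exactly the cost a low-dimensional embedding helps to shrink.&lt;/p&gt;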

&lt;p&gt;&lt;a name=&quot;nca-rescue&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;nca-to-the-rescue&quot;&gt;NCA to the Rescue&lt;/h2&gt;

&lt;p&gt;Rather than having the user specify some arbitrary distance metric, NCA &lt;em&gt;learns&lt;/em&gt; it by choosing a parameterized family of quadratic distance metrics, constructing a loss function of the parameters, and optimizing it with gradient descent. Furthermore, the learned distance metric can explicitly be made low-dimensional, solving test-time storage and search issues. How does NCA do this?&lt;/p&gt;

&lt;p&gt;It turns out that learning a quadratic distance metric &lt;script type=&quot;math/tex&quot;&gt;d&lt;/script&gt; of the input space where the performance of kNN is maximized is equivalent to learning a linear transformation &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt; of the input space, such that in the transformed space, the performance of kNN with a Euclidean distance metric is maximized. In fact, quadratic distance metrics&lt;sup id=&quot;fnref:3&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt; can be represented by a positive semi-definite matrix &lt;script type=&quot;math/tex&quot;&gt;Q = \mathcal{A}^T \mathcal{A}&lt;/script&gt; such that, writing &lt;script type=&quot;math/tex&quot;&gt;y = \mathcal{A}x&lt;/script&gt;:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{equation} \label{eq1}
\begin{split}
d(x_1, x_2) &amp;= (x_1 - x_2)^T Q (x_1- x_2) \\
 &amp; = (\mathcal{A}x_1 - \mathcal{A}x_2)^T (\mathcal{A}x_1 - \mathcal{A}x_2) \\
 &amp;= \langle y_1 - y_2, y_1 - y_2 \rangle
\end{split}
\end{equation} %]]&gt;&lt;/script&gt;
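
&lt;p&gt;The identity is easy to check numerically. The sketch below uses a random matrix and two random points as placeholders (none of these values come from the paper) and compares the quadratic form in the input space with the squared Euclidean distance in the transformed space.&lt;/p&gt;

```python
import torch

torch.manual_seed(0)
D, d = 3, 2  # input and embedding dimensions (arbitrary choices)
A = torch.randn(d, D, dtype=torch.double)
x1 = torch.randn(D, dtype=torch.double)
x2 = torch.randn(D, dtype=torch.double)

# quadratic form in the input space with Q = A^T A
Q = A.t() @ A
diff = x1 - x2
d_quadratic = diff @ Q @ diff

# squared Euclidean distance in the transformed space
y1, y2 = A @ x1, A @ x2
d_euclidean = ((y1 - y2) ** 2).sum()

# the two quantities agree up to floating point error
same = torch.allclose(d_quadratic, d_euclidean)
```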

&lt;p&gt;The goal of the learning algorithm then, is to optimize the performance of kNN on future test data. Since we don’t a priori know the test data, we can choose instead to optimize the closest thing in our toolbox: the &lt;a href=&quot;https://en.wikipedia.org/wiki/Cross-validation_(statistics)#Leave-one-out_cross-validation&quot;&gt;&lt;em&gt;leave-one-out&lt;/em&gt;&lt;/a&gt; (LOO) performance of the training data.&lt;/p&gt;

&lt;p&gt;At this point, I’d like the reader to appreciate the elegance of NCA. We’ve transformed the problem of maximizing the classification accuracy of kNN into an optimization problem involving a two-dimensional matrix &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt;. What remains is specifying a loss function that’s parameterized by &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt; and that can serve as a proxy for the LOO classification accuracy.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/nca/loo-disc.png&quot; width=&quot;50%&quot; style=&quot;border:none;&quot; /&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;b&gt;Figure 3:&lt;/b&gt; The discontinuous graph of the LOO cross validation error. The red rectangle in particular illustrates how an infinitesimal change in the x-axis may change the value of the y-axis by a finite amount.&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;&lt;a name=&quot;loss-func&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Formulating The Loss Function.&lt;/strong&gt; There’s a slight bump in our road: LOO error is a highly discontinuous loss function. The reason is that it depends solely on the neighbourhood graph of each point. If the distance metric changes slightly at first, there might be no change in the neighbourhood graph and thus no change of the LOO error. But then suddenly, an infinitesimal change in the metric can alter the neighbourhood graph of many points, causing a significant jump in the LOO error. This is illustrated in the figure above.&lt;/p&gt;

&lt;p&gt;Clearly, a discontinuous loss function is terrible for optimization, so we need to construct an alternative that is smooth and differentiable. The key to doing this is to replace &lt;em&gt;fixed&lt;/em&gt; neighbourhood selection (i.e. what is done in LOO cross-validation) with &lt;em&gt;stochastic&lt;/em&gt; neighbourhood selection. That is, each point &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt; in the training set selects another point &lt;script type=&quot;math/tex&quot;&gt;j&lt;/script&gt; as its neighbour with some probability &lt;script type=&quot;math/tex&quot;&gt;p_{ij}&lt;/script&gt; that decays exponentially with the squared Euclidean distance &lt;script type=&quot;math/tex&quot;&gt;d_{ij}&lt;/script&gt; between them in the transformed space. By summing &lt;script type=&quot;math/tex&quot;&gt;p_{ij}&lt;/script&gt; over all points &lt;script type=&quot;math/tex&quot;&gt;j&lt;/script&gt; that share the class of &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt;, we can compute the probability &lt;script type=&quot;math/tex&quot;&gt;p_i&lt;/script&gt; that a point &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt; will be correctly classified and then sum over all values of &lt;script type=&quot;math/tex&quot;&gt;p_i&lt;/script&gt; to obtain the total number of points we can expect to correctly classify.&lt;/p&gt;

&lt;p&gt;Denoting the set of points in the same class as &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt; by &lt;script type=&quot;math/tex&quot;&gt;C_i&lt;/script&gt;, our loss function&lt;sup id=&quot;fnref:4&quot;&gt;&lt;a href=&quot;#fn:4&quot; class=&quot;footnote&quot;&gt;4&lt;/a&gt;&lt;/sup&gt; thus becomes:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\mathcal{L}(X; \mathcal{A}) = -\sum_i p_i = - \sum_i \sum_{j \in C_i} p_{ij}&lt;/script&gt;

&lt;p&gt;where&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{equation} \label{eq2}
\begin{split}
p_{ij} &amp;= \frac{e^{-d_{ij}}}{\sum_{k \neq i} e^{-d_{ik}}} \\
 &amp;= \frac{\exp{\big(-\lVert \mathcal{A}x_i - \mathcal{A}x_j \rVert^2\big)}}{\sum_{k \neq i} \exp{\big(-\lVert \mathcal{A}x_i - \mathcal{A}x_k \rVert^2\big)}}
\end{split}
\end{equation} %]]&gt;&lt;/script&gt;

&lt;p&gt;The really neat thing about this stochastic assignment is that we’ve completely avoided having to specify a value of &lt;script type=&quot;math/tex&quot;&gt;k&lt;/script&gt;. It gets learned implicitly through the scale of the matrix &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt;:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;With larger values of &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt;, the distance between points increases and as a result their probabilities decrease (think exponential of smaller and smaller values). This means kNN will consult fewer neighbours for each point.&lt;/li&gt;
  &lt;li&gt;With smaller values of &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt;, the distance between points decreases and as a result their probabilities increase (think exponential of larger and larger values). This means kNN will consult more neighbours for each point.&lt;/li&gt;
&lt;/ul&gt;
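
&lt;p&gt;A small sketch of this effect, under simplifying assumptions: the data are five made-up points on a line and the transformation is just a scalar multiple of the identity, so that &lt;code&gt;scale&lt;/code&gt; plays the role of the magnitude of the matrix. The helper &lt;code&gt;neighbour_probs&lt;/code&gt; is a hypothetical name.&lt;/p&gt;

```python
import torch

# five points on a line; made-up data for illustration
X = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0]])


def neighbour_probs(X, scale):
    """Row-wise p_ij for A = scale * identity, with p_ii forced to 0."""
    emb = X * scale
    dist = torch.cdist(emb, emb) ** 2
    dist.fill_diagonal_(float("inf"))
    return torch.softmax(-dist, dim=1)


for scale in (0.1, 1.0, 3.0):
    p = neighbour_probs(X, scale)
    # probability mass that point 0 puts on its single closest neighbour
    print(scale, round(p[0, 1].item(), 3))
```

&lt;p&gt;With a small scale the mass is spread over many points (a large effective neighbourhood), while with a large scale almost all of it lands on the closest point.&lt;/p&gt;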

&lt;p&gt;&lt;a name=&quot;contrastive&quot;&gt;&lt;/a&gt;
&lt;strong&gt;NCA as a special case of the contrastive loss.&lt;/strong&gt; If we slightly alter our loss function to sum over log probabilities &lt;script type=&quot;math/tex&quot;&gt;-\sum_i \log{p_i}&lt;/script&gt;, you’ll notice it looks just like a categorical cross entropy loss. In fact, you can think of NCA as a single hidden layer feed-forward neural network that performs metric learning with a contrastive loss function. Recall that a contrastive loss takes on the form:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\mathcal{L}_ {contr} = \alpha \mathcal{L}_ {pos} + \beta \mathcal{L}_ {neg}&lt;/script&gt;

&lt;p&gt;In most papers, &lt;script type=&quot;math/tex&quot;&gt;\mathcal{L}_ {pos}&lt;/script&gt; is an L2 loss, &lt;script type=&quot;math/tex&quot;&gt;\mathcal{L}_ {neg}&lt;/script&gt; is a hinge loss and &lt;script type=&quot;math/tex&quot;&gt;\alpha = \beta = 1&lt;/script&gt;. The NCA loss function uses a categorical cross-entropy loss for &lt;script type=&quot;math/tex&quot;&gt;\mathcal{L}_ {pos}&lt;/script&gt; with &lt;script type=&quot;math/tex&quot;&gt;\alpha = 1&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;\beta = 0&lt;/script&gt;. This insight is going to be very valuable in our implementation of NCA when we talk about tricks to stabilize the training.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;pytorch&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;nca-in-pytorch&quot;&gt;NCA In PyTorch&lt;/h2&gt;

&lt;p&gt;There’s currently no GPU-accelerated version of NCA. The two most common ones at the time of this post are sklearn’s Python &lt;a href=&quot;https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NeighborhoodComponentsAnalysis.html&quot;&gt;implementation&lt;/a&gt; and a C++ &lt;a href=&quot;https://github.com/jhseu/nca&quot;&gt;implementation&lt;/a&gt;. This meant I had the perfect excuse to implement a version in PyTorch that could leverage (a) &lt;a href=&quot;https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html&quot;&gt;automatic differentiation&lt;/a&gt; to compute the gradient of the loss function with respect to &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt; and (b) blazing fast GPU acceleration that would prove super useful for large datasets. While the implementation was pretty straightforward, getting it to converge consistently took quite a while. In this section, I’ll walk you through the high-level components needed to implement NCA plus all the additional bells and whistles I added to get it to converge. The entirety of the code is available on &lt;a href=&quot;https://github.com/kevinzakka/torchnca&quot;&gt;GitHub&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;init&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Initialization.&lt;/strong&gt; Since NCA is a gradient-based iterative optimization process, it requires that we specify an initialization strategy for the matrix &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt;. The two obvious ones (no, not zero init!) are identity initialization and random initialization. Recall that if &lt;script type=&quot;math/tex&quot;&gt;d&lt;/script&gt; is the chosen dimension of the embedding space, and if &lt;script type=&quot;math/tex&quot;&gt;X \in \mathbb{R}^{N \ \times \ D}&lt;/script&gt; is our input dataset, then &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A} \in \mathbb{R}^{d \ \times \ D}&lt;/script&gt;.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# feature space dimension
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# embedding space dimension
&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;init&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;random&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;c1&quot;&gt;# random init from a normal distribution
&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# with mean 0 and standard deviation 0.01
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Parameter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;init&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;identity&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;c1&quot;&gt;# identity init
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Parameter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;a name=&quot;comp-loss&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Loss Function.&lt;/strong&gt; Computing the loss function requires forming a matrix of pairwise Euclidean distances in the transformed space, applying a softmax over the negative distances to compute pairwise probabilities, then summing over probabilities belonging to the same class. The trick here is to vectorize the softmax computation whilst ignoring diagonal values of the distance matrix (i.e. values where &lt;script type=&quot;math/tex&quot;&gt;i = j&lt;/script&gt;) and probabilities that don’t have the same class labels.&lt;/p&gt;

&lt;p&gt;To compute a pairwise Euclidean distance matrix, we make use of the following code:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;pairwise_l2_sq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Compute pairwise squared Euclidean distances.
  &quot;&quot;&quot;&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;double&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;double&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()))&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;norm_sq&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;norm_sq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;norm_sq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clamp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# replace negative values with 0
&lt;/span&gt;  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Note the cast to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;double&lt;/code&gt; to increase numerical precision in the dot product computation and the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;clamp&lt;/code&gt; method to replace any negative values that could have arisen from numerical imprecisions with zeros.&lt;/p&gt;

&lt;p&gt;Next, we want to compute a softmax over the negative distances to obtain the pairwise probability matrix &lt;script type=&quot;math/tex&quot;&gt;p_{ij}&lt;/script&gt;. Unlike a typical softmax implementation, the denominator in our equation sums over all &lt;script type=&quot;math/tex&quot;&gt;k \neq i&lt;/script&gt;, i.e. it skips the diagonal entries of the pairwise distance matrix. A neat trick to achieve this without modifying the softmax function is to fill the diagonal entries with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.inf&lt;/code&gt;. That way, taking the exponential of their negative evaluates to 0 and doesn’t contribute to the normalization.&lt;/p&gt;
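
&lt;p&gt;Here is the diagonal trick in isolation, as a minimal sketch on a made-up 3x3 distance matrix. It uses &lt;code&gt;float(&quot;inf&quot;)&lt;/code&gt; in place of &lt;code&gt;np.inf&lt;/code&gt; so that the snippet only depends on PyTorch.&lt;/p&gt;

```python
import torch

# made-up 3x3 matrix of squared distances; the diagonal is d_ii = 0
distances = torch.tensor([[0.0, 1.0, 4.0],
                          [1.0, 0.0, 1.0],
                          [4.0, 1.0, 0.0]])

# exp(-inf) = 0, so the diagonal drops out of the normalization
masked = distances.clone()
masked.fill_diagonal_(float("inf"))
p_ij = torch.softmax(-masked, dim=1)
```

&lt;p&gt;Each row of &lt;code&gt;p_ij&lt;/code&gt; still sums to one, but the diagonal entries are exactly zero.&lt;/p&gt;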

&lt;p&gt;Now for each row &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt; in &lt;script type=&quot;math/tex&quot;&gt;p_{ij}&lt;/script&gt;, we need to sum over all columns &lt;script type=&quot;math/tex&quot;&gt;j \in C_i&lt;/script&gt;. We can achieve this simply by creating a pairwise boolean mask of class labels, element-wise multiplying it with &lt;script type=&quot;math/tex&quot;&gt;p_{ij}&lt;/script&gt; then calling the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sum&lt;/code&gt; method. The code below executes all the aforementioned computations:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# compute pairwise boolean class label mask
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# compute pairwise squared Euclidean distances
# in transformed space
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
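&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pairwise_l2_sq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# fill the diagonal with inf so that exp(-d_ii) = 0, i.e. a point
# never selects itself as a neighbour
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fill_diagonal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;inf&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# compute pairwise probability matrix p_ij defined by a
# row-wise softmax over negative squared distances in the transformed
# space. since every entry is non-positive, the exponentials cannot
# overflow
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_ij&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;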

&lt;span class=&quot;c1&quot;&gt;# for each p_i, zero out any p_ij that is not of the same
# class label as i
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_ij_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_ij&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_mask&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# sum over js to compute p_i
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_ij_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# compute expected number of points correctly classified by summing
# over all p_i's.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;a name=&quot;sgd&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Replacing Conjugate Gradients with SGD.&lt;/strong&gt; The authors originally optimized NCA with conjugate gradients. I decided to stick with mini-batch Stochastic Gradient Descent. My reasoning was two-fold. First, with very large datasets, the size of the pairwise matrix grows quadratically with the number of points so it was essential that I use a mini-batch optimizer that could run very fast on a memory-limited GPU. Second, SGD has been shown to be a tried and true optimizer in deep learning that &lt;a href=&quot;https://ruder.io/optimizing-gradient-descent/&quot;&gt;tends to generalize better&lt;/a&gt; than its counterparts.&lt;/p&gt;
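
&lt;p&gt;As a rough sketch of what mini-batch training can look like (not the exact loop from the repository): the synthetic data, learning rate, batch size and number of epochs below are all assumptions chosen to keep the example small, and the loss is averaged over the batch rather than summed so that its scale does not depend on the batch size.&lt;/p&gt;

```python
import torch


def pairwise_sq_dists(x):
    """Squared Euclidean distances between all rows of x."""
    dot = x @ x.t()
    norm_sq = dot.diagonal()
    return (norm_sq[None, :] - 2 * dot + norm_sq[:, None]).clamp(min=0)


def nca_loss(A, X, y):
    """Negative expected fraction of correctly classified points in a batch."""
    dist = pairwise_sq_dists(X @ A.t())
    # inf on the diagonal so that a point never picks itself
    eye = torch.eye(len(X), dtype=torch.bool)
    p_ij = torch.softmax(-dist.masked_fill(eye, float("inf")), dim=1)
    same_class = (y[:, None] == y[None, :]).float()
    return -(p_ij * same_class).sum() / len(X)


torch.manual_seed(0)
# synthetic two-class data in 3D (made up): classes differ along axis 0
N, D, d = 200, 3, 2
y = torch.randint(0, 2, (N,))
X = torch.randn(N, D)
X[:, 0] += 4.0 * y.float()

A = torch.nn.Parameter(torch.eye(d, D))  # identity initialization
optimizer = torch.optim.SGD([A], lr=0.1)  # assumed learning rate
batch_size = 100  # assumed; each step only builds a 100 x 100 matrix

for epoch in range(20):
    perm = torch.randperm(N)
    for start in range(0, N, batch_size):
        idx = perm[start:start + batch_size]
        optimizer.zero_grad()
        loss = nca_loss(A, X[idx], y[idx])
        loss.backward()
        optimizer.step()

final_loss = nca_loss(A, X, y).item()
```

&lt;p&gt;The pairwise matrix is only ever built per batch, so memory scales with the square of the batch size instead of the square of the dataset size.&lt;/p&gt;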

&lt;p&gt;&lt;a name=&quot;tricks&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Stability Tricks.&lt;/strong&gt; It took an intense session of debugging to get the implementation to work consistently across the various initializations and input data sizes. Here are the tricks that helped, in no particular order:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Summing over log probabilities was more stable than the non-log variant. In other words, I ended up using a categorical cross-entropy loss.&lt;/li&gt;
  &lt;li&gt;Initially, the random initialization was sampled from a unit variance Gaussian. Lowering the variance to 0.01 seemed to make the optimization more stable.&lt;/li&gt;
  &lt;li&gt;Selecting the batch size was crucial for convergence. A small batch size leads to a very jittery loss function. This makes sense intuitively: a small batch means the pairwise matrix is only a very crude approximation of the neighbourhood graph since it only considers a random subset of all possible neighbours. I noticed a good rule of thumb was to try to maximize the batch size within the GPU memory limits.&lt;/li&gt;
  &lt;li&gt;Normalizing the input data (i.e. subtracting the mean and dividing by the standard deviation) helped with convergence. Note that doing this requires that we store the computed statistics and scale any test data appropriately.&lt;/li&gt;
  &lt;li&gt;Without L2 regularization, the final matrix &lt;script type=&quot;math/tex&quot;&gt;\mathcal{A}&lt;/script&gt; tended to blow up in scale. Adding L2 regularization to the loss function helped tame the matrix and speed-up convergence.&lt;/li&gt;
  &lt;li&gt;Random init always converged to a collapsed projection where the points lay on a hyperplane. This is possible because there is no term in the loss function that explicitly pulls different classes apart. To combat this, I added a hinge loss component to the loss function, essentially turning the NCA loss into a contrastive loss function.&lt;/li&gt;
&lt;/ul&gt;
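
&lt;p&gt;To make the first and fifth tricks concrete, here is a minimal sketch of a log-probability loss with an L2 penalty. It is written in plain Python so the arithmetic is easy to follow, and it is not the code from the actual implementation: the function names, the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;weight_decay&lt;/code&gt; default and the averaging over points are assumptions made for illustration. It also assumes every class has at least two points, otherwise a point has no same-class neighbour to pick.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import math

def logsumexp(values):
    # stable log(sum(exp(v))): factor out the largest term first
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def nca_log_loss(X, y, A, weight_decay=1e-3):
    # embed every point with the linear map A: z = A x
    Z = [[sum(a * v for a, v in zip(row, x)) for row in A] for x in X]
    total = 0.0
    for i, zi in enumerate(Z):
        # negative squared distance from point i to every other point j
        scores = [(-sum((p - q) ** 2 for p, q in zip(zi, zj)), y[j])
                  for j, zj in enumerate(Z) if j != i]
        same = [s for s, label in scores if label == y[i]]
        # log p_i: log of the softmax mass on same-class neighbours
        total -= logsumexp(same) - logsumexp([s for s, _ in scores])
    # L2 penalty on the entries of A keeps its scale in check
    l2 = sum(a * a for row in A for a in row)
    return total / len(X) + weight_decay * l2
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Working with log-sum-exp instead of exponentiating and then taking the log is what buys the extra stability: far-away points contribute tiny softmax weights that would otherwise underflow.&lt;/p&gt;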

&lt;p&gt;&lt;a name=&quot;results&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;boring-show-me-what-it-can-do&quot;&gt;Boring… Show Me What It Can Do!&lt;/h2&gt;

&lt;p&gt;At this point, you’re probably curious to know if NCA lives up to its claims. Let’s go ahead and test the PyTorch implementation on 2 tasks: dimensionality reduction and kNN classification.&lt;/p&gt;

&lt;p&gt;Using the NCA API is super simple. Very briefly, you first instantiate an NCA object with an embedding dimension and an initialization strategy. Then you call the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train&lt;/code&gt; method on the input and ground-truth tensors, specifying a batch size and learning rate. There are other parameters you can change, all documented in the class docstring.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;nca&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NCA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;random&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# instantiate nca object
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nca&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# fit nca model
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_nca&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nca&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# apply the learned transformation
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;a name=&quot;dim-reduct&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h4 id=&quot;dimensionality-reduction&quot;&gt;Dimensionality Reduction&lt;/h4&gt;

&lt;p&gt;For this task, I replicated a portion of the results from section 4 of the paper. Specifically, I generated a synthetic three-dimensional dataset which consists of 5 classes, shown in different colors in Figure 4. The first two dimensions of the dataset correspond to concentric circles, while the third dimension is just Gaussian noise with high variance.&lt;/p&gt;
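
&lt;p&gt;For reference, a dataset of this shape can be generated in a few lines. This is only a sketch of the idea, not the script behind Figure 4: the radii, the number of points per class and the noise scale &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;noise_std&lt;/code&gt; are placeholder values I picked for illustration.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import math
import random

def make_concentric(n_per_class=100, n_classes=5, noise_std=3.0, seed=0):
    # concentric circles in the first two dims, Gaussian noise in the third
    rng = random.Random(seed)
    X, y = [], []
    for c in range(n_classes):
        radius = c + 1.0  # one ring per class (hypothetical radii)
        for _ in range(n_per_class):
            theta = rng.uniform(0.0, 2.0 * math.pi)
            X.append([radius * math.cos(theta),
                      radius * math.sin(theta),
                      rng.gauss(0.0, noise_std)])  # uninformative third dim
            y.append(c)
    return X, y

X, y = make_concentric()
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;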

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/nca/res.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;b&gt;Figure 4:&lt;/b&gt; NCA vs. PCA vs. LDA on the synthetic dataset.&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;I then embed the dataset in a 2D space using PCA, LDA and NCA. The results are shown in Figure 4. While NCA seems to have recovered the original concentric pattern, PCA fails to project out the noise, a direct consequence of the high variance of the noise. If we lower the noise variance to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.1&lt;/code&gt; for example, PCA successfully recovers the pattern. LDA also struggles to recover the concentric pattern since the classes themselves are not linearly separable.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;sentiment&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h4 id=&quot;knn-on-mnist&quot;&gt;kNN On MNIST&lt;/h4&gt;

&lt;p&gt;The whole motivation for NCA was that it would vastly reduce the storage and search costs of kNN for high-dimensional datasets. To put this to the test, we compared the storage, run time and error rates of two variants of kNN on the MNIST dataset:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;5-NN on the raw MNIST dataset (784 dimensional)&lt;/li&gt;
  &lt;li&gt;5-NN on the 32 dimensional NCA projection of MNIST&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The results are shown in the table below.&lt;/p&gt;

&lt;style&gt;
table {
  font-family: arial, sans-serif;
  border-collapse: collapse;
  width: 100%;
}

td, th {
  border: 1px solid #dddddd;
  text-align: left;
  padding: 8px;
}

tr:nth-child(even) {
  background-color: #f0f0f0;
}
&lt;/style&gt;

&lt;table&gt;
  &lt;tr&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;Algorithm&lt;/th&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;Raw kNN&lt;/th&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;NCA + kNN&lt;/th&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;Error (%)&lt;/th&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;2.8&lt;/td&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;3.3&lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;Time (s)&lt;/th&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;155.25&lt;/td&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;2.37&lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
    &lt;th style=&quot;text-align: center&quot;&gt;Storage (MB)&lt;/th&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;156.8&lt;/td&gt;
    &lt;td style=&quot;text-align: center&quot;&gt;6.40&lt;/td&gt;
  &lt;/tr&gt;
&lt;/table&gt;

&lt;p&gt;That’s a 66x speedup in time and a 25x reduction in storage&lt;sup id=&quot;fnref:5&quot;&gt;&lt;a href=&quot;#fn:5&quot; class=&quot;footnote&quot;&gt;5&lt;/a&gt;&lt;/sup&gt;!&lt;/p&gt;
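
&lt;p&gt;The ratios follow directly from the table. The second half of the snippet is a sanity check under an assumption of mine that is not stated above: 50,000 stored points at 4 bytes (float32) per value happens to reproduce the storage row exactly.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;knn_time_s, nca_time_s = 155.25, 2.37
knn_storage_mb, nca_storage_mb = 156.8, 6.40

speedup = knn_time_s / nca_time_s          # roughly 65.5
saving = knn_storage_mb / nca_storage_mb   # 24.5

# assumed setup: 50,000 stored points, 4 bytes (float32) per value
n_points, bytes_per_value = 50_000, 4
raw_mb = n_points * 784 * bytes_per_value / 1e6  # raw 784-dimensional pixels
nca_mb = n_points * 32 * bytes_per_value / 1e6   # 32-dimensional NCA embedding
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;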

&lt;p&gt;&lt;a name=&quot;thankyou&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;acknowledgements&quot;&gt;Acknowledgements&lt;/h2&gt;

&lt;p&gt;I’d like to thank Nick Hynes, Alex Nichol and Brent Yi for their valuable feedback throughout my debugging session and blog writing. I also want to thank Chris Choy for the insight he provided on mode collapse. The JavaScript code for the animation was adapted from Sam Greydanus’ &lt;a href=&quot;https://greydanus.github.io/&quot;&gt;blog&lt;/a&gt; – check him out, he’s got some great content.&lt;/p&gt;

&lt;hr /&gt;
&lt;div class=&quot;footnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot;&gt;
      &lt;p&gt;The paper in question is &lt;a href=&quot;https://arxiv.org/abs/1904.07846&quot;&gt;Temporal Cycle Consistency Learning&lt;/a&gt; from Dwibedi et al. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot;&gt;
      &lt;p&gt;John Schulman discusses this in more depth in his latest &lt;a href=&quot;http://joschu.net/blog/opinionated-guide-ml-research.html&quot;&gt;blog post&lt;/a&gt;. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot;&gt;
      &lt;p&gt;You can convince yourself that this is a valid distance metric by checking that the non-negativity, symmetry and triangle inequality conditions are satisfied. &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:4&quot;&gt;
      &lt;p&gt;We negate the expression because our goal is to maximize the expectation and we’re going to feed it to an optimizer that performs minimization. &lt;a href=&quot;#fnref:4&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:5&quot;&gt;
      &lt;p&gt;Performance on MNIST isn’t very representative of real world performance on tougher datasets but this is still a very cool result. &lt;a href=&quot;#fnref:5&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;
</description>
        <pubDate>Mon, 10 Feb 2020 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2020/02/10/nca/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2020/02/10/nca/</guid>
        
        <category>machine learning</category>
        
        <category>metric learning</category>
        
        <category>knn</category>
        
        <category>nca</category>
        
        
      </item>
    
      <item>
        <title>Learning to Assemble and to Generalize from Self-Supervised Disassembly</title>
        <description>&lt;p&gt;This is a crosspost from the official &lt;a href=&quot;https://ai.googleblog.com/2019/10/learning-to-assemble-and-to-generalize.html&quot;&gt;Google AI Blog&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Our physical world is full of different shapes, and learning how they are all interconnected is a natural part of interacting with our surroundings. For example, we understand that coat hangers hook onto clothing racks, power plugs insert into wall outlets, and USB cables fit into USB sockets. This general concept of “how things fit together” based on their shapes is something that we acquire over time and experience, and it helps to increase the efficiency with which we perform tasks, like assembling DIY furniture kits or packing gifts into a box. If robots could also learn “how things fit together,” then perhaps they could become more adaptable to new manipulation tasks involving objects they have never seen before, like reconnecting severed pipes, or building makeshift shelters by piecing together debris during disaster response scenarios.&lt;/p&gt;

&lt;p&gt;To explore this idea, we worked with researchers from Stanford and Columbia Universities to develop &lt;a href=&quot;https://form2fit.github.io/&quot;&gt;Form2Fit&lt;/a&gt;, a robotic manipulation algorithm that uses deep neural networks to learn to visually recognize how objects correspond (or “fit”) to each other. To test this algorithm, we tasked a real robot to perform kit assembly, where it needed to accurately assemble objects into a blister pack or corrugated display to form a single unit. Previous systems built for this task required extensive manual tuning to assemble a single kit unit at a time. However, we demonstrate that by learning the general concept of “how things fit together,” Form2Fit enables our robot to assemble various types of kits with a 94% success rate. Furthermore, Form2Fit is one of the first systems capable of generalizing to new objects and kitting tasks not seen during training.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/teaser-white.gif&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Form2Fit learns to assemble a wide variety of kits by finding geometric correspondences between object surfaces and their target placement locations. By leveraging geometric information learned from multiple kits during training, the system generalizes to new objects and kits.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;While often overlooked, shape analysis plays an important role in manipulation, especially for tasks like kit assembly. In fact, the shape of an object often matches the shape of its corresponding space in the packaging, and understanding this relationship is what allows people to do this task with minimal guesswork. At its core, Form2Fit aims to learn this relationship by training over numerous pairs of objects and their corresponding placing locations across multiple different kitting tasks – with the goal to acquire a broader understanding of how shapes and surfaces fit together. Form2Fit improves itself over time with minimal human supervision, gathering its own training data by repeatedly disassembling completed kits through trial and error, then time-reversing the disassembly sequences to get assembly trajectories. After training overnight for 12 hours, our robot learns effective pick and place policies for a variety of kits, achieving 94% assembly success rates with objects and kits in varying configurations, and over 86% assembly success rates when handling completely new objects and kits.&lt;/p&gt;

&lt;h3 id=&quot;data-driven-shape-descriptors-for-generalizable-assembly&quot;&gt;Data-Driven Shape Descriptors For Generalizable Assembly&lt;/h3&gt;

&lt;p&gt;The core of Form2Fit is a two-stream matching network that learns to infer orientation-sensitive geometric pixel-wise descriptors for objects and their target placement locations from visual data. These descriptors can be understood as compressed 3D point representations that encode object geometry, textures, and contextual task-level knowledge. Form2Fit uses these descriptors to establish correspondences between objects and their target locations (i.e., where they should be placed). Since these descriptors are orientation-sensitive, they allow Form2Fit to infer how the picked object should be rotated before it is placed in its target location.&lt;/p&gt;

&lt;p&gt;Form2Fit uses two additional networks to generate valid pick and place candidates. A suction network gets fed a 3D image of the objects and generates pixel-wise predictions of suction success. The suction probability map is visualized as a heatmap, where hotter pixels indicate better locations to grasp the object at the 3D location of the corresponding pixel. In parallel, a place network gets fed a 3D image of the target kit and outputs pixel-wise predictions of placement success. These, too, are visualized as a heatmap, where higher confidence values serve as better locations for the robot arm to approach from a top-down angle to place the object. Finally, the planner integrates the output of all three modules to produce the final pick location, place location and rotation angle.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/overview.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Overview of Form2Fit: the suction and place networks infer candidate picking and placing locations in the scene respectively. The matching network generates pixel-wise orientation-sensitive descriptors to match picking locations to their corresponding placing locations. The planner then integrates it all to control the robot to execute the next best pick and place action.&lt;/div&gt;
  &lt;/p&gt;
&lt;/div&gt;

&lt;h3 id=&quot;learning-assembly-from-disassembly&quot;&gt;Learning Assembly from Disassembly&lt;/h3&gt;

&lt;p&gt;Neural networks require large amounts of training data, which can be difficult to collect for tasks like assembly. Precisely inserting objects into tight spaces with the correct orientation (e.g., in kits) is challenging to learn through trial and error, because the chances of success from random exploration can be slim. In contrast, disassembling completed units is often easier to learn through trial and error, since there are fewer incorrect ways to remove an object than there are to correctly insert it. We leveraged this difference in order to amass training data for Form2Fit.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/trimmed.gif&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;An example of self-supervision through time-reversal: rewinding a disassembly sequence of a deodorant kit over time generates a valid assembly sequence.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;Our key observation is that in many cases of kit assembly, a disassembly sequence – when reversed over time – becomes a valid assembly sequence. This concept, called &lt;a href=&quot;https://arxiv.org/abs/1810.01128&quot;&gt;time-reversed disassembly&lt;/a&gt;, enables Form2Fit to train entirely through self-supervision by randomly picking with trial and error to disassemble a fully-assembled kit, then reversing that disassembly sequence to learn how the kit should be put together.&lt;/p&gt;

&lt;h3 id=&quot;generalization-results&quot;&gt;Generalization Results&lt;/h3&gt;

&lt;p&gt;The results of our experiments show great potential for learning generalizable policies for assembly. For instance, when a policy is trained to assemble a kit in only one specific position and orientation, it can still robustly assemble random rotations and translations of the kit 90% of the time.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/init.gif&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Form2Fit policies are robust to a wide range of rotations and translations of the kits.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;We also find that Form2Fit is capable of tackling novel configurations it has not been exposed to during training. For example, when training a policy on two single-object kits (floss and tape), we find that it can successfully assemble new combinations and mixtures of those kits, even though it has never seen such configurations before.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/res1.gif&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Form2Fit policies can generalize to novel kit configurations such as multiple versions of the same kit and mixtures of different kits.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;Furthermore, when given completely novel kits on which it has not been trained, Form2Fit can generalize using its learned shape priors to assemble those kits with over 86% assembly accuracy.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/res2.gif&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Form2Fit policies can generalize to never-before-seen single and multi-object kits.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;h3 id=&quot;what-have-the-descriptors-learned&quot;&gt;What Have the Descriptors Learned?&lt;/h3&gt;

&lt;p&gt;To explore what the descriptors of the matching network from Form2Fit have learned to encode, we visualize the pixel-wise descriptors of various objects in RGB colorspace through use of an embedding technique called &lt;a href=&quot;https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding&quot;&gt;t-SNE&lt;/a&gt;.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
  &lt;img src=&quot;/assets/form2fit/tsne.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;The t-SNE embedding of the learned object descriptors. Similarly oriented objects of the same category display identical colors (e.g. A, B or F, G), while different objects (e.g. C, H) or same objects with different orientations (e.g. A, C, D or H, F) exhibit different colors.&lt;/div&gt;
  &lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;We observe that the descriptors have learned to encode (a) rotation — objects oriented differently have different descriptors (A, C, D, E) and (H, F); (b) spatial correspondence — same points on the same oriented objects share similar descriptors (A, B) and (F, G); and (c) object identity — zoo animals and fruits exhibit unique descriptors (columns 3 and 4).&lt;/p&gt;

&lt;h3 id=&quot;limitations--future-work&quot;&gt;Limitations &amp;amp; Future Work&lt;/h3&gt;

&lt;p&gt;While Form2Fit’s results are promising, its limitations suggest directions for future work. In our experiments, we assume a 2D planar workspace to constrain the kit assembly task so that it can be solved by sequencing top-down picking and placing actions. This may not work for all cases of assembly – for example, when a peg needs to be precisely inserted at a 45 degree angle. It would be interesting to expand Form2Fit to more complex action representations for 3D assembly.&lt;/p&gt;

&lt;p&gt;You can learn more about this work and download the code from our &lt;a href=&quot;https://github.com/kevinzakka/form2fit&quot;&gt;GitHub repository&lt;/a&gt;.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;iframe width=&quot;560&quot; height=&quot;315&quot; src=&quot;https://www.youtube.com/embed/exnMwDmS1QI&quot; frameborder=&quot;0&quot; allow=&quot;accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture&quot; allowfullscreen=&quot;&quot;&gt;&lt;/iframe&gt;
&lt;/p&gt;

&lt;h3 id=&quot;acknowledgments&quot;&gt;Acknowledgments&lt;/h3&gt;

&lt;p&gt;This research was done by Kevin Zakka, Andy Zeng, Johnny Lee, and Shuran Song (faculty at Columbia University), with special thanks to Nick Hynes, Alex Nichol, and Ivan Krasin for fruitful technical discussions; Adrian Wong, Brandon Hurd, Julian Salazar, and Sean Snyder for hardware support; Ryan Hickman for valuable managerial support; and Chad Richards for helpful feedback on writing.&lt;/p&gt;
</description>
        <pubDate>Thu, 31 Oct 2019 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2019/10/31/form2fit/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2019/10/31/form2fit/</guid>
        
        <category>robotics</category>
        
        <category>research</category>
        
        
      </item>
    
      <item>
        <title>Manifesto</title>
        <description>&lt;p&gt;I find writing to be a very fascinating and therapeutic activity. There’s nothing quite like twiddling a bunch of words into a sequence, reading the result out loud, grimacing, and adjusting until it sounds just right. It’s the reason I started this blog yet I find that I haven’t been able to write as much as I would like to. It sucks, but articles on here have usually been academic and because I prioritize quality over quantity, finding the time to write them has been very challenging.&lt;/p&gt;

&lt;p&gt;To combat this dry spell, I’ve decided to create a new section of the blog entitled &lt;strong&gt;Miscellany&lt;/strong&gt;, where I’ll post on a variety of topics such as interesting research papers, books I read, and philosophical ponderings of life. I still intend to publish on the main section, but posts there will be reserved for tutorials and research expositions primarily in machine learning. I’m aiming to write once a month and while it’s not much, it’s still better than nothing. As &lt;a href=&quot;https://www.youtube.com/watch?v=46GwJbrMghQ&amp;amp;feature=youtu.be&amp;amp;t=172&quot;&gt;Andy Dufresne puts it beautifully&lt;/a&gt; in &lt;em&gt;The Shawshank Redemption&lt;/em&gt;:&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;Get busy living, or get busy dying&lt;sup id=&quot;fnref:1&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;div class=&quot;footnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot;&gt;
      &lt;p&gt;A tad bit dramatic for my case, but I couldn’t resist. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;
</description>
        <pubDate>Sun, 26 May 2019 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2019/05/26/manifesto/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2019/05/26/manifesto/</guid>
        
        <category>personal</category>
        
        
      </item>
    
      <item>
        <title>Dex-Net 2.0: Deep Learning to Plan Robust Grasps</title>
        <description>&lt;p&gt;In this blog post, we’re going to take a close look at &lt;a href=&quot;https://arxiv.org/abs/1703.09312&quot;&gt;Dex-Net 2.0: Deep Learning to Plan Robust Grasps with Synthetic Point Clouds and Analytic Grasp Metrics&lt;/a&gt; by &lt;em&gt;Jeffrey Mahler&lt;/em&gt;, &lt;em&gt;Jacky Liang&lt;/em&gt;, &lt;em&gt;Sherdil Niyaz&lt;/em&gt;, &lt;em&gt;Michael Laskey&lt;/em&gt;, &lt;em&gt;Richard Doan&lt;/em&gt;, &lt;em&gt;Xinyu Liu&lt;/em&gt;, &lt;em&gt;Juan Aparicio Ojea&lt;/em&gt;, and &lt;em&gt;Ken Goldberg&lt;/em&gt;.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/teaser.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Overview&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;TL;DR.&lt;/strong&gt; This paper tackles grasp planning, which is the task of finding a gripper configuration (pose and width) that maximizes a success metric subject to kinematic and collision constraints. The suggested approach is to train a Grasp Quality Convolutional Neural Network (GQ-CNN) on a large synthetic dataset of depth images with associated positive and negative grasps. Then, at test time, one can sample various grasps from a depth image, feed each through the GQ-CNN, pick the one with the highest probability of success, and execute the grasp open-loop.&lt;/p&gt;
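
&lt;p&gt;In code, the test-time procedure boils down to an argmax over sampled candidates. The sketch below is only illustrative: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sample_grasps&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;gq_cnn&lt;/code&gt; are hypothetical callables standing in for the grasp sampler and the trained network, not functions from the Dex-Net codebase, and the default number of candidates is arbitrary.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def plan_grasp(depth_image, sample_grasps, gq_cnn, n_candidates=100):
    # draw candidate grasps from the depth image (hypothetical sampler)
    candidates = sample_grasps(depth_image, n_candidates)
    # score every candidate and keep the one most likely to succeed
    return max(candidates, key=lambda u: gq_cnn(u, depth_image))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;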

&lt;h3 id=&quot;variables&quot;&gt;Variables&lt;/h3&gt;

&lt;p&gt;Let’s start by introducing the variables that appear in the paper.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;x = (O, T_o, T_c, \gamma)&lt;/script&gt;: the state describing the variable properties of the camera and objects in the environment, where:
    &lt;ul&gt;
      &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;O&lt;/script&gt;: the geometry and mass properties of the object.&lt;/li&gt;
      &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;T_o, T_c&lt;/script&gt;: 3D poses of the object and camera respectively.&lt;/li&gt;
      &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;\gamma&lt;/script&gt;: the coefficient of friction between the object and the gripper.&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;u = (p, \phi)&lt;/script&gt;: a parallel-jaw grasp in 3D space, specified by a center &lt;script type=&quot;math/tex&quot;&gt;p = (x, y, z)&lt;/script&gt; relative to the camera and an angle in the table plane &lt;script type=&quot;math/tex&quot;&gt;\phi&lt;/script&gt;.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;y \in \mathbb{R}^{H \times W}&lt;/script&gt;: a point cloud represented as a depth image with height &lt;script type=&quot;math/tex&quot;&gt;H&lt;/script&gt; and width &lt;script type=&quot;math/tex&quot;&gt;W&lt;/script&gt;, taken by the camera with known intrinsics &lt;script type=&quot;math/tex&quot;&gt;K&lt;/script&gt; and pose &lt;script type=&quot;math/tex&quot;&gt;T_c&lt;/script&gt;.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;S(u, x) \in \{0, 1\}&lt;/script&gt;: a binary-valued grasp success metric, such as force closure.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Using these random variables, we can define a joint distribution &lt;script type=&quot;math/tex&quot;&gt;p(S, x, u, y)&lt;/script&gt; that models the inherent uncertainty associated with our assumptions, such as erroneous sensor readings (calibration error, noise, limitations of the pinhole model, etc.) and imprecise control (kinematic inaccuracies, etc.).&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Goal.&lt;/strong&gt; Ingest a depth image &lt;script type=&quot;math/tex&quot;&gt;y&lt;/script&gt; of an object in a scene with an associated grasp candidate &lt;script type=&quot;math/tex&quot;&gt;u&lt;/script&gt;, and spit out the probability that &lt;script type=&quot;math/tex&quot;&gt;u&lt;/script&gt; will succeed under the above uncertainties. This is equivalent to predicting the &lt;strong&gt;robustness&lt;/strong&gt; &lt;script type=&quot;math/tex&quot;&gt;Q&lt;/script&gt; of a grasp, defined as the expected value of &lt;script type=&quot;math/tex&quot;&gt;S&lt;/script&gt; conditioned on &lt;script type=&quot;math/tex&quot;&gt;u&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;y&lt;/script&gt;, i.e. &lt;script type=&quot;math/tex&quot;&gt;Q(u, y) = \mathbb{E}[S \vert u, y]&lt;/script&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Solution.&lt;/strong&gt; Use a neural network with weights &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt; to approximate the complex, high-dimensional function &lt;script type=&quot;math/tex&quot;&gt;Q&lt;/script&gt;. Concretely,&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\hat{\theta} = \arg \min_{\theta} \ \mathbb{E}_{p(S, u, x, y)} \big[L(S, Q_{\theta}(u, y)) \big]&lt;/script&gt;

&lt;p&gt;And finally, using Monte-Carlo sampling of input-output pairs from our joint distribution, we obtain:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\hat{\theta} = \arg\min_{\theta} \sum_{i=1}^{N} L(S_i, Q_{\theta}(u_i, y_i))&lt;/script&gt;

&lt;p&gt;where &lt;script type=&quot;math/tex&quot;&gt;(S_i, u_i, x_i, y_i) \sim p(S, u, x, y)&lt;/script&gt;.&lt;/p&gt;
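
&lt;p&gt;As a rough sketch, the Monte-Carlo objective is just a sum of per-sample losses over the drawn tuples. The post leaves the loss unspecified at this point, so the binary cross-entropy below is an assumption (a natural choice for a 0/1 success label), and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Q_theta&lt;/code&gt; is a hypothetical callable standing in for the network.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import math

def cross_entropy(S, q, eps=1e-12):
    # binary cross-entropy between a 0/1 label S and a predicted robustness q
    q = min(max(q, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(S * math.log(q) + (1 - S) * math.log(1.0 - q))

def empirical_risk(samples, Q_theta):
    # samples: list of (S_i, u_i, y_i) tuples drawn from the joint distribution;
    # the state x_i is used to generate them but is never shown to the network
    return sum(cross_entropy(S, Q_theta(u, y)) for S, u, y in samples)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Training then amounts to searching for the weights that make this sum as small as possible.&lt;/p&gt;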

&lt;h3 id=&quot;generative-graphical-model&quot;&gt;Generative Graphical Model&lt;/h3&gt;

&lt;p&gt;We can think of our joint &lt;script type=&quot;math/tex&quot;&gt;p(S, x, u, y)&lt;/script&gt; as a generative model of images, grasps and success metrics. The relationship between the different variables is illustrated in the graphical model below.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/gm.png&quot; width=&quot;45%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Graphical Model&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Using the &lt;a href=&quot;https://en.wikipedia.org/wiki/Chain_rule_(probability)&quot;&gt;chain rule&lt;/a&gt;, we can express the joint &lt;script type=&quot;math/tex&quot;&gt;p(S, x, u, y)&lt;/script&gt; as the product of 4 terms: &lt;script type=&quot;math/tex&quot;&gt;p(S \vert u, y, x)&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;p(u \vert x, y)&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;p(y \vert x)&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;p(x)&lt;/script&gt;. And since &lt;script type=&quot;math/tex&quot;&gt;S&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;u&lt;/script&gt; are conditionally independent of &lt;script type=&quot;math/tex&quot;&gt;y&lt;/script&gt; given &lt;script type=&quot;math/tex&quot;&gt;x&lt;/script&gt; (there is no arrow going from &lt;script type=&quot;math/tex&quot;&gt;y&lt;/script&gt; to &lt;script type=&quot;math/tex&quot;&gt;S&lt;/script&gt; or &lt;script type=&quot;math/tex&quot;&gt;u&lt;/script&gt;), we can reduce the expression to&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;p(S, u, y, x) = {\color{red}{p(S \vert u, x)}} \cdot {\color{orange}{p(u \vert x)}} \cdot {\color{blue}{p(y \vert x)}} \cdot {\color{green}{p(x)}}&lt;/script&gt;

&lt;p&gt;where:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;{\color{green}{p(x)}}&lt;/script&gt; is the state distribution.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;{\color{blue}{p(y \vert x)}}&lt;/script&gt; is the observation model, conditioned on the current state.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;{\color{orange}{p(u \vert x)}}&lt;/script&gt; is the grasp candidate model, conditioned on the current state.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;{\color{red}{p(S \vert u, x)}}&lt;/script&gt; is the analytic model of grasp success conditioned on the grasp candidate and current state.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The state &lt;script type=&quot;math/tex&quot;&gt;x = (O, T_o, T_c, \gamma)&lt;/script&gt; is represented by the blue nodes in the graphical model. Using the chain rule and independence properties, we can express its underlying distribution as the product of:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{align*}
{\color{green}{p(x)}}
&amp;= p(\gamma \vert T_c, T_o, O) \cdot p(T_c \vert T_o, O) \cdot p(T_o \vert O) \cdot p(O) \\
&amp;= p(\gamma) \cdot p(T_c) \cdot p(T_o \vert O) \cdot p(O)
\end{align*} %]]&gt;&lt;/script&gt;

&lt;p&gt;with:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;p(\gamma)&lt;/script&gt;: truncated Gaussian over friction coefficients.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;p(O)&lt;/script&gt;: discrete uniform distribution over 3D object models.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;p(T_o \vert O)&lt;/script&gt;: continuous uniform distribution over discrete set of stable object poses.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;p(T_c)&lt;/script&gt;: continuous uniform distribution over spherical coordinates and polar angle.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The grasp candidate model &lt;script type=&quot;math/tex&quot;&gt;{\color{orange}{p(u \vert x)}}&lt;/script&gt; is a uniform distribution over pairs of antipodal contact points on the object surface whose grasp axis is parallel to the table plane (we want top-down grasps), the observation model &lt;script type=&quot;math/tex&quot;&gt;{\color{blue}{p(y \vert x)}}&lt;/script&gt; is a rendered depth image of the scene corrupted with multiplicative and Gaussian Process noise, and the success model &lt;script type=&quot;math/tex&quot;&gt;{\color{red}{p(S \vert u, x)}}&lt;/script&gt; is a binary-valued reward function subject to 2 constraints: epsilon quality and collision freedom.&lt;/p&gt;

&lt;p&gt;Now that we’ve examined the inner workings of our generative model &lt;script type=&quot;math/tex&quot;&gt;p&lt;/script&gt;, let’s see how we can use it to generate the massive Dex-Net dataset.&lt;/p&gt;

&lt;h3 id=&quot;generating-dex-net&quot;&gt;Generating Dex-Net&lt;/h3&gt;

&lt;p&gt;To train our GQ-CNN, we need to generate i.i.d. samples, consisting of depth images, grasps, and grasp robustness labels, by sampling from the generative joint &lt;script type=&quot;math/tex&quot;&gt;p(S, x, u, y)&lt;/script&gt;.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/data-gen.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Data Generation Pipeline&lt;/div&gt;
&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;Randomly select, from a database of 1,500 meshes, a 3D object mesh using a discrete uniform distribution.&lt;/li&gt;
  &lt;li&gt;Randomly select, from a set of stable poses, a pose for this object using a continuous uniform distribution.&lt;/li&gt;
  &lt;li&gt;Use rejection sampling to generate top-down parallel-jaw grasps covering the surface of the object.&lt;/li&gt;
  &lt;li&gt;Randomly sample the camera pose (also from a continuous uniform distribution) and use it to render the object and its pose w.r.t. the camera into a depth image using ray tracing.&lt;/li&gt;
  &lt;li&gt;Classify the robustness of each sampled grasp to obtain a set of positive and negative grasps. Robustness is estimated using force closure probability, which is a function of object pose, gripper pose, and friction coefficient uncertainty.&lt;/li&gt;
&lt;/ol&gt;
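
&lt;p&gt;The state-sampling part of the steps above can be sketched as ancestral sampling, i.e. drawing each variable after its parents in the graphical model. This is only a rough sketch: the mesh names, pose names and every numeric range are made-up placeholders, the planar rotation of the object on the table is ignored, and grasp sampling, rendering and robustness labelling are left out because they require a simulator.&lt;/p&gt;

```python
import random
from statistics import NormalDist


def sample_truncated_gaussian(rng, mu, sigma, lo, hi):
    """Draw from a Gaussian truncated to [lo, hi] using the inverse-CDF method."""
    dist = NormalDist(mu, sigma)
    p = rng.uniform(dist.cdf(lo), dist.cdf(hi))
    return dist.inv_cdf(p)


def sample_state(rng, meshes, stable_poses):
    """One ancestral sample of the state x = (O, T_o, T_c, gamma).

    All ranges below are placeholders, not the values used for Dex-Net.
    """
    obj = rng.choice(meshes)                  # O ~ discrete uniform over meshes
    pose = rng.choice(stable_poses[obj])      # T_o ~ p(T_o | O), stable poses only
    camera = {                                # T_c ~ uniform over spherical coordinates
        "radius": rng.uniform(0.5, 0.7),
        "polar_angle": rng.uniform(0.0, 0.1),
        "azimuth": rng.uniform(0.0, 6.283),
    }
    friction = sample_truncated_gaussian(rng, 0.5, 0.1, 0.3, 0.7)  # gamma
    return {"object": obj, "pose": pose, "camera": camera, "friction": friction}


rng = random.Random(0)
meshes = ["mug", "wrench"]
stable_poses = {"mug": ["upright", "on_side"], "wrench": ["flat"]}
state = sample_state(rng, meshes, stable_poses)
```

&lt;p&gt;In a full pipeline, each sampled state would then be handed to a grasp sampler, a depth renderer and the analytic success model to produce one labelled training example.&lt;/p&gt;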

&lt;h3 id=&quot;training-the-gq-cnn&quot;&gt;Training the GQ-CNN&lt;/h3&gt;

&lt;p&gt;Once the synthetic dataset has been generated, it becomes trivial to train the network.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/model.png&quot; width=&quot;65%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Overview of the Model&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Remember how we mentioned that GQ-CNN takes as input a depth image and a grasp candidate? Well, it turns out that the authors have a very clever way of encoding the grasp information into the depth image: they take a depth image and grasp candidate and transform the depth image such that the grasp pixel location &lt;script type=&quot;math/tex&quot;&gt;(i, j)&lt;/script&gt; – projected from the grasp position &lt;script type=&quot;math/tex&quot;&gt;(x, y)&lt;/script&gt; – is aligned with the image center and the grasp axis &lt;script type=&quot;math/tex&quot;&gt;\varphi&lt;/script&gt; corresponds to the middle row of the image. Then, at every iteration of SGD, we sample the transformed depth image and the remaining grasp variable &lt;script type=&quot;math/tex&quot;&gt;z&lt;/script&gt; (i.e. the gripper depth from the camera), normalize the depth image to zero mean and unit standard deviation, and feed the tuple to the 18M-parameter GQ-CNN model.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Note 1.&lt;/strong&gt; The model is a typical deep learning architecture composed of convolutional, max-pool and fully-connected primitives.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Note 2.&lt;/strong&gt; The depth alignment makes it easier for the model to train since it doesn’t have to worry about any rotational invariances. As for feeding the gripper depth to the model, I would think this is useful for pruning grasps that have the correct 2D position and orientation, but are too far away from the object (i.e. either not touching or barely touching).&lt;/p&gt;

&lt;h3 id=&quot;grasp-planning-inference-time&quot;&gt;Grasp Planning (Inference Time)&lt;/h3&gt;

&lt;p&gt;Once the model is trained, we can pair the GQ-CNN with a policy of choice. The one used in the paper is &lt;script type=&quot;math/tex&quot;&gt;\pi_{\theta}(y) = \arg \max_{u \in C} Q_{\theta}(u, y)&lt;/script&gt; which amounts to sampling a set of predefined grasps from a depth image subject to a set of constraints &lt;script type=&quot;math/tex&quot;&gt;C&lt;/script&gt; (e.g. kinematic and collision constraints), scoring each grasp using the GQ-CNN, and finally executing the most robust grasp. There are two sampling strategies used to generate grasp candidates: antipodal grasp sampling and cross-entropy sampling.&lt;/p&gt;
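
&lt;p&gt;The arg max policy itself is only a few lines once candidates and a scoring function are available. In the sketch below, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;q_theta&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;is_feasible&lt;/code&gt; are hypothetical callables standing in for the trained GQ-CNN and the constraint check, and the toy grasps are invented tuples.&lt;/p&gt;

```python
def plan_grasp(candidates, depth_image, q_theta, is_feasible):
    """Return the feasible candidate with the highest predicted robustness."""
    feasible = [u for u in candidates if is_feasible(u)]
    if not feasible:
        return None  # no candidate satisfies the constraints C
    return max(feasible, key=lambda u: q_theta(u, depth_image))


# Toy usage: grasps are (name, score, collision_free) tuples and the "network"
# simply reads back the stored score.
candidates = [("a", 0.2, True), ("b", 0.9, False), ("c", 0.6, True)]
best = plan_grasp(
    candidates,
    depth_image=None,
    q_theta=lambda u, y: u[1],
    is_feasible=lambda u: u[2],
)
```

&lt;p&gt;Grasp &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b&lt;/code&gt; has the highest score but violates the constraint, so the policy returns &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;c&lt;/code&gt;.&lt;/p&gt;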

&lt;p&gt;&lt;strong&gt;Antipodal Grasp Sampling.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;First, we perform edge detection by locating pixel areas with high gradient magnitude. This is especially useful since graspable regions usually correspond to contact points on opposite edges of an object.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/edge-detection.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;Then we sample pairs of pixels belonging to these areas to generate antipodal contact points on the object. We enforce the constraint that the axis joining each point pair is parallel to the table plane.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/cands.gif&quot; width=&quot;50%&quot; style=&quot;border:none;&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;We repeat this step until we reach the desired number of grasps, potentially increasing the friction coefficient if too few are found. In the final step, 2D grasps are deprojected to 3D grasps using the camera intrinsics and extrinsics, and multiple grasps are obtained from the same contact points by discretizing the height starting from the object surface to the table surface (&lt;script type=&quot;math/tex&quot;&gt;h = 0&lt;/script&gt;).&lt;/p&gt;
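
&lt;p&gt;For readers unfamiliar with deprojection, here is a small sketch using the standard pinhole camera model. The intrinsics, rotation and translation below are hypothetical values chosen for the example, and lens distortion is ignored; the actual Dex-Net code may handle this differently.&lt;/p&gt;

```python
def deproject(col, row, depth, fx, fy, cx, cy):
    """Map a pixel (col, row) with known depth to a 3D point in the camera frame.

    Pinhole model: fx, fy are the focal lengths in pixels and (cx, cy) is the
    principal point.
    """
    x = (col - cx) * depth / fx
    y = (row - cy) * depth / fy
    return (x, y, depth)


def to_world(point, rotation, translation):
    """Apply the camera extrinsics: world = R * point + t, with R a 3x3 nested list."""
    return tuple(
        sum(rotation[r][c] * point[c] for c in range(3)) + translation[r]
        for r in range(3)
    )


# Hypothetical intrinsics, identity rotation and a translation of 1 m along z.
# The principal-point pixel lands on the optical axis.
p_cam = deproject(col=320, row=240, depth=0.7, fx=525.0, fy=525.0, cx=320.0, cy=240.0)
identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
p_world = to_world(p_cam, identity, (0.0, 0.0, 1.0))
```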

&lt;p&gt;&lt;strong&gt;Cross Entropy Method.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/cem.png&quot; width=&quot;75%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Evolution of grasp robustness as the gripper center sweeps the depth image from top to bottom.&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Randomly choosing a grasp from a set of candidates doesn’t work very well in cases where the grasping regions are small and require very precise gripper configurations. Taking a look at the image above, we can see that as we sweep candidate grasps from top to bottom, grasp robustness stays near zero and spikes momentarily when we reach the good, yet narrow grasping area. Thus, uniform sampling of grasp candidates is inefficient especially since we’re trying to perform real-time grasp planning.&lt;/p&gt;

&lt;p&gt;This is where importance sampling – one of &lt;a href=&quot;https://kevinzakka.github.io/2018/09/28/prioritized-learning/&quot;&gt;my favorite&lt;/a&gt; techniques – can help! We can modify our sampling strategy such that at every iteration, we refit the candidate distribution to the grasps with the highest predicted robustness. The algorithm to perform this fitting is the cross-entropy method (CEM), which tries to minimize the cross-entropy between a mixture of Gaussians and the top-k percentile of grasps ranked by GQ-CNN. The result is that at every iteration, we are more likely to sample grasps with high robustness values (grasps in the spike area) and converge to an optimal grasp candidate. This fitting process is illustrated below.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/dexnet/sampled.gif&quot; width=&quot;50%&quot; style=&quot;border:none;&quot; /&gt;
&lt;/div&gt;
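
&lt;p&gt;The sample, rank and refit loop can also be written down in a few lines. The sketch below is a deliberately simplified version: it uses a single 1-D Gaussian instead of the mixture of Gaussians mentioned above, and a made-up function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;toy_robustness&lt;/code&gt; with a narrow spike in place of the GQ-CNN.&lt;/p&gt;

```python
import math
import random
import statistics


def cem_maximize(score, mu, sigma, rng, n_samples=100, elite_frac=0.1, n_iters=10):
    """Cross-entropy method with a single 1-D Gaussian as the sampling distribution."""
    n_elite = max(2, int(n_samples * elite_frac))
    for _ in range(n_iters):
        samples = [rng.gauss(mu, sigma) for _ in range(n_samples)]
        # Keep the top fraction of samples ranked by their score.
        elites = sorted(samples, key=score, reverse=True)[:n_elite]
        # Refit the Gaussian to the elites (maximum-likelihood mean and std).
        mu = statistics.mean(elites)
        sigma = max(statistics.pstdev(elites), 1e-3)  # floor avoids a collapsed distribution
    return mu


def toy_robustness(x):
    """Stand-in for the GQ-CNN: near zero everywhere except a narrow spike at 0.7."""
    return math.exp(-((x - 0.7) / 0.05) ** 2)


best = cem_maximize(toy_robustness, mu=0.5, sigma=0.3, rng=random.Random(0))
```

&lt;p&gt;Starting from a broad distribution centred at 0.5, the mean should drift toward the spike at 0.7 as the elites concentrate there.&lt;/p&gt;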

&lt;h3 id=&quot;discussion&quot;&gt;Discussion&lt;/h3&gt;

&lt;ul&gt;
  &lt;li&gt;The sampling of grasps is inefficient. It would be interesting to extend the GQ-CNN to a fully-convolutional architecture where robustness labels can be computed for every pixel in the depth image in a single forward pass.&lt;/li&gt;
  &lt;li&gt;Dex-Net is open-loop which means that once a grasp candidate has been picked, it is executed blindly with no visual feedback. This sets it up for failure when camera calibration is imprecise or the environment it is placed in is dynamic and susceptible to change.&lt;/li&gt;
  &lt;li&gt;If we can speed up Dex-Net by creating a smaller, fully-convolutional GQ-CNN, we may be able to run it at a high enough frequency to incorporate visual feedback and close the loop.&lt;/li&gt;
&lt;/ul&gt;
</description>
        <pubDate>Mon, 05 Nov 2018 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2018/11/05/dexnet/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2018/11/05/dexnet/</guid>
        
        <category>grasping</category>
        
        <category>robotics</category>
        
        <category>cnn</category>
        
        
      </item>
    
      <item>
        <title>Learning What to Learn and When to Learn It</title>
        <description>&lt;p&gt;&lt;small&gt;&lt;b&gt;Disclaimer&lt;/b&gt;: This blog post describes unfinished research and should be treated as a work in progress.&lt;/small&gt;&lt;/p&gt;

&lt;p&gt;Hello world! I’m coming out of hibernation after 14 months of radio silence on this blog. I have a lot of things to blog about, from my research internship at Stanford University this past summer, to wrapping up my B.Eng. in EE in July – and I’ll hopefully get to those in future blog posts – but today, I’d like to talk about some of the cool research I did in my senior year of undergrad. Unfortunately, it’s not GAN/RL related (read as fortunately) but it’s definitely an interesting aspect of the field that could use some more attention.&lt;/p&gt;

&lt;p&gt;The problem we’ll be investigating today is whether we can get Deep Neural Networks (DNNs) to converge faster and learn more efficiently. In particular, we’ll try to answer the following questions:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Do we &lt;em&gt;really&lt;/em&gt; need all the training samples in a dataset to reach a desired accuracy?&lt;/li&gt;
  &lt;li&gt;Can we do better than (lazy) uniform sampling of the data in a given training epoch?&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;It actually turns out that on MNIST, we can reliably speed up training by a factor of 2 using just 30% of the available data&lt;sup id=&quot;fnref:1&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;!&lt;/p&gt;

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Motivation&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;Refresher&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;Stochastic Gradient Descent&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;Importance Sampling&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc5&quot;&gt;Quantifying Sample Importance&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc6&quot;&gt;Loss Patterns&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc7&quot;&gt;SGD on Steroids&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc8&quot;&gt;Mini-Batch Resampling&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc9&quot;&gt;Auxiliary Model&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc10&quot;&gt;Things I Wish I Tried&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc11&quot;&gt;Closing Thoughts&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;motivation&quot;&gt;Motivation&lt;/h2&gt;

&lt;p&gt;Human beings acquire knowledge in a unique way, accelerating their learning by choosing where and when to focus their efforts on the available training material. For example, when practicing a new musical composition, a pianist will spend more time on the difficult measures – breaking them down into manageable pieces that can be progressively mastered – rather than wasting her efforts on the simpler, more familiar parts.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/pr-lr/music-sheet-bach.jpg&quot; width=&quot;75%&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://www.thestrad.com/yehudi-menuhins-marked-up-copy-of-bachs-solo-violin-sonata-no2/6651.article&quot;&gt;Annotated Copy of Bach’s Solo Violin Sonata No. 2&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Much of the same can be said about our formal primary and secondary education: our teachers help us learn from a smart selection of examples, leveraging previously acquired concepts to help guide our learning of new tools and abstractions. Human learning thus exhibits &lt;strong&gt;resource&lt;/strong&gt; and &lt;strong&gt;time&lt;/strong&gt; efficiency: we become proficient at mastering new concepts by selecting first, a &lt;em&gt;subset&lt;/em&gt; of what is available to us in terms of learning material, and second, the &lt;em&gt;sequence&lt;/em&gt; in which to learn the selected items such that we minimize acquisition time.&lt;/p&gt;

&lt;p&gt;Unfortunately, the training algorithms we use in AI, unlike human learning, are data-hungry and time-consuming. With vanilla stochastic gradient descent (SGD) for example, the standard go-to optimizer, we repeatedly iterate over the training data in sequential mini-batches for a large number of epochs, where a mini-batch is constructed by uniformly sampling &lt;script type=&quot;math/tex&quot;&gt;b&lt;/script&gt; training points from the dataset. On large datasets – a necessity for good generalization – the naiveté of this sampling strategy hinders convergence and bottlenecks computation.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;refresher&quot;&gt;Refresher&lt;/h2&gt;

&lt;p&gt;So how can we improve SGD? Can we replace uniform sampling with a more efficient sampling distribution? More specifically, can we somehow predict a sample’s importance such that we adaptively construct training batches that catalyze more learning-per-iteration? These are all excellent questions we’ll be tackling further in the post, so let’s begin by refreshing a few concepts.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Stochastic Gradient Descent.&lt;/strong&gt; Given a neural network &lt;script type=&quot;math/tex&quot;&gt;M&lt;/script&gt; parameterized by a set of weights &lt;script type=&quot;math/tex&quot;&gt;W&lt;/script&gt;, a dataset &lt;script type=&quot;math/tex&quot;&gt;\mathcal{D}&lt;/script&gt;, and a loss function &lt;script type=&quot;math/tex&quot;&gt;L&lt;/script&gt;, we can express the goal of training as finding the optimal set of weights &lt;script type=&quot;math/tex&quot;&gt;\hat{W}&lt;/script&gt; such that,&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{equation}
\begin{split}
\hat{W} &amp; = \arg \min_{W} \ L_{\mathcal{D}} \\
&amp; = \arg \min_{W} \ \frac{1}{B} \sum_{i=1}^{B} L_i \\
&amp; = \arg \min_{W} \frac{1}{B} \sum_{i=1}^{B} \sum_{j=1}^{b} L_{ij} \big( M(x_j; W), y_j \big) \\
\end{split}
\end{equation} %]]&gt;&lt;/script&gt;

&lt;p&gt;where &lt;script type=&quot;math/tex&quot;&gt;B&lt;/script&gt; corresponds to the number of batches in an epoch, &lt;script type=&quot;math/tex&quot;&gt;b&lt;/script&gt; the number of training observations in a batch, and &lt;script type=&quot;math/tex&quot;&gt;(x_j, y_j)&lt;/script&gt; an input-output training pair.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/pr-lr/sgd.png&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://distill.pub/2017/momentum/&quot;&gt;Converging to an Optimum with SGD&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Without loss of generality, we can simplify the notation by considering just one training observation, a special case where the batch size is equal to 1. In that case, training our neural network &lt;script type=&quot;math/tex&quot;&gt;M&lt;/script&gt; amounts to updating the weight vector &lt;script type=&quot;math/tex&quot;&gt;W&lt;/script&gt; between two consecutive iterations by taking a small step in the direction opposite to the gradient of the loss with respect to &lt;script type=&quot;math/tex&quot;&gt;W&lt;/script&gt;:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;W_{t+1} = W_t - \alpha \ \mu_i \ \nabla_{W_t} L_i&lt;/script&gt;

&lt;p&gt;In the above equation, &lt;script type=&quot;math/tex&quot;&gt;i&lt;/script&gt; is a discrete random variable sampled from &lt;script type=&quot;math/tex&quot;&gt;\mathcal{D}&lt;/script&gt; according to a probability distribution &lt;script type=&quot;math/tex&quot;&gt;\mathcal{P}&lt;/script&gt; with probabilities &lt;script type=&quot;math/tex&quot;&gt;p_i&lt;/script&gt; and sampling weights &lt;script type=&quot;math/tex&quot;&gt;\mu_i&lt;/script&gt;. With vanilla SGD and uniform sampling, we have that &lt;script type=&quot;math/tex&quot;&gt;\forall i \in \mathcal{D}&lt;/script&gt;,&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\begin{equation*}
    \mu_i = 1, \qquad p_i = \frac{1}{|\mathcal{D}|}
\end{equation*}&lt;/script&gt;
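
&lt;p&gt;To see the update rule in action, here is a toy sketch that fits a single weight with single-sample SGD, uniform probabilities and unit sampling weights. The linear model, the noise-free data and the step size are all made up for the example and have nothing to do with the MNIST experiments.&lt;/p&gt;

```python
import random


def sgd_fit(xs, ys, probs, mus, alpha=0.01, n_steps=500, seed=0):
    """Fit y = w * x with single-sample SGD, drawing index i with probability probs[i]."""
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    w = 0.0
    for _ in range(n_steps):
        i = rng.choices(indices, weights=probs, k=1)[0]
        grad = 2.0 * xs[i] * (w * xs[i] - ys[i])  # d/dw of the squared error (w*x_i - y_i)^2
        w = w - alpha * mus[i] * grad             # W_{t+1} = W_t - alpha * mu_i * grad
    return w


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [3.0 * x for x in xs]      # noise-free data with true slope 3
n = len(xs)
uniform_probs = [1.0 / n] * n   # p_i = 1 / |D|
unit_weights = [1.0] * n        # mu_i = 1
w_hat = sgd_fit(xs, ys, uniform_probs, unit_weights)
```

&lt;p&gt;If the probabilities are made non-uniform, one standard way to keep the expected update unchanged is to set each sampling weight to &lt;script type=&quot;math/tex&quot;&gt;\mu_i = 1 / (|\mathcal{D}| \, p_i)&lt;/script&gt;, which is the same reweighting trick as in importance sampling.&lt;/p&gt;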

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Importance Sampling.&lt;/strong&gt; Importance sampling is a neat little trick for reducing the variance of an integral estimation by selecting a better distribution from which to sample a random variable. The trick is to multiply the integrand by a cleverly disguised &lt;script type=&quot;math/tex&quot;&gt;1&lt;/script&gt;:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{equation}
\begin{split}
E_{x \sim p(x)} \big[\ f(x) \big] &amp; = \int f(x)\ p(x)\ dx \\
&amp; = \int f(x)\ p(x)\ \frac{q(x)}{q(x)}\ dx \\
&amp; = \int \frac{p(x)}{q(x)}\cdot f(x)\ q(x)\ dx \\
&amp; = E_{x \sim q(x)} \big[\ f(x)\cdot \frac{p(x)}{q(x)} \big] \\
\end{split}
\end{equation} %]]&gt;&lt;/script&gt;

&lt;p&gt;Since many quantities of interest (probabilities, sums, integrals)&lt;sup id=&quot;fnref:2&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt; can be obtained by computing the mean of a function of a random variable &lt;script type=&quot;math/tex&quot;&gt;E[f(X)]&lt;/script&gt;, we can greatly accelerate – and even improve – Monte-Carlo estimates by switching out the original probability distribution with a density that minimizes the sampling of points that contribute very little to the estimate, i.e. points with a function value of 0.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
  &lt;img src=&quot;/assets/pr-lr/mc-imp.jpg&quot; width=&quot;80%&quot; style=&quot;border:none;&quot; /&gt;
  &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Smaller Point Spread with Importance Sampling&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;For a tutorial on Monte-Carlo estimation and Importance Sampling, click &lt;a href=&quot;https://github.com/kevinzakka/blog-code/blob/master/pr-lr/Monte%20Carlo%20and%20Importance%20Sampling.ipynb&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;
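
&lt;p&gt;As a quick self-contained illustration (separate from the notebook linked above), the sketch below estimates &lt;script type=&quot;math/tex&quot;&gt;E_{x \sim p}[f(x)]&lt;/script&gt; for a narrow bump &lt;script type=&quot;math/tex&quot;&gt;f&lt;/script&gt; sitting in the tail of a standard normal &lt;script type=&quot;math/tex&quot;&gt;p&lt;/script&gt;. The bump, the proposal &lt;script type=&quot;math/tex&quot;&gt;q&lt;/script&gt; and the sample size are arbitrary choices made for the example.&lt;/p&gt;

```python
import math
import random
from statistics import NormalDist


def f(x):
    """A narrow bump centred at 3, far out in the tail of p."""
    return math.exp(-50.0 * (x - 3.0) ** 2)


p = NormalDist(0.0, 1.0)  # target distribution
q = NormalDist(3.0, 0.5)  # proposal concentrated where f is non-negligible


def naive_estimate(n, rng):
    """Plain Monte-Carlo: average f(x) with x drawn from p."""
    return sum(f(rng.gauss(p.mean, p.stdev)) for _ in range(n)) / n


def importance_estimate(n, rng):
    """Importance sampling: average f(x) * p(x) / q(x) with x drawn from q."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(q.mean, q.stdev)
        total += f(x) * p.pdf(x) / q.pdf(x)
    return total / n


# Closed form of E_p[f] for this Gaussian integrand, used as a reference value.
exact = math.exp(-450.0 / 101.0) / math.sqrt(101.0)
rng = random.Random(0)
print(exact, naive_estimate(20000, rng), importance_estimate(20000, rng))
```

&lt;p&gt;Most samples drawn from &lt;script type=&quot;math/tex&quot;&gt;p&lt;/script&gt; land where &lt;script type=&quot;math/tex&quot;&gt;f&lt;/script&gt; is essentially zero, so the naive estimate tends to be much noisier than the importance-sampled one for the same number of samples.&lt;/p&gt;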

&lt;p&gt;&lt;a name=&quot;toc5&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;quantifying-sample-importance&quot;&gt;Quantifying Sample Importance&lt;/h2&gt;

&lt;p&gt;In the previous section, we mentioned that uniform sampling assigns equal importance to all the training points in &lt;script type=&quot;math/tex&quot;&gt;\mathcal{D}&lt;/script&gt;. This is obviously wasteful: while some samples are “easy” for the model and can be discarded in the initial stages with minimal impact on performance, the more “difficult” samples should be addressed more frequently throughout the training since they contribute to faster learning. So can we find a way to quantify this “importance”?&lt;/p&gt;

&lt;p&gt;Fortunately, the answer is yes: we can theoretically&lt;sup id=&quot;fnref:3&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt; show that this quantity is none other than the norm of the gradient of a sample. Intuitively this makes sense: in the classification setting for example, we would expect misclassified examples to exhibit larger gradients than their correctly classified counterparts. Unfortunately, the norm of the gradient is pretty expensive to compute, especially in settings where we would like to avoid computing a full forward and backwards pass.&lt;/p&gt;

&lt;p&gt;What about the loss of a sample? We essentially get it for free in the forward pass, so if we can show some degree of correlation with the gradient norm, it would be a less accurate but way cheaper metric for importance. Let’s try and verify this with a small PyTorch experiment. We’re going to train a small convnet on MNIST and record both the loss and the gradient norm of every image in an epoch. We’ll then sort the list containing the gradient norms and use it to index the list of losses. A scatter plot of the reindexed losses should reveal a few things:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;If there &lt;em&gt;is&lt;/em&gt; indeed a correlation, there should be a (potentially noisy) straight line through the scatter plot.&lt;/li&gt;
  &lt;li&gt;If the correlation is positive – implying that a higher gradient norm corresponds to a higher loss value and vice versa – this line should be increasing.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;EDIT (08/06/2019)&lt;/strong&gt;: @AruniRC kindly mentioned that I can compute the &lt;a href=&quot;https://en.wikipedia.org/wiki/Pearson_correlation_coefficient&quot;&gt;Pearson correlation coefficient&lt;/a&gt; to numerically quantify the degree of correlation between the gradient norm and the loss value. I’ve now added a cell in the notebook to compute it.&lt;/p&gt;
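
&lt;p&gt;For reference, the coefficient is simple enough to compute by hand. The sketch below is a plain-Python version and is not the notebook cell mentioned above; the numbers are invented stand-ins for per-image gradient norms and losses.&lt;/p&gt;

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient between two equally long sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Made-up numbers standing in for per-image gradient norms and losses.
grad_norms = [0.1, 0.4, 0.9, 1.5, 2.2]
losses = [0.05, 0.30, 0.70, 1.10, 1.80]
r = pearson(grad_norms, losses)  # close to +1 for this nearly linear toy data
```

&lt;p&gt;A value near +1 indicates a strong positive linear relationship, a value near 0 indicates none, and a value near -1 indicates a strong negative one.&lt;/p&gt;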

&lt;p&gt;Here’s a code snippet for computing the L2 norm of the gradient of a batch of losses with respect to the parameters of the network. Since there’s a pair of weights and biases associated with every convolutional and fully-connected layer and we want to return a scalar, we can calculate and return the square root of the sum of the squared gradient norms.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;gradient_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;norms&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad_params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;autograd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;create_graph&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;pow&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;norms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;norms&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Incorporating the above function in the training loop is pretty trivial. All we need to do is record a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(grad_norm, loss)&lt;/code&gt; tuple for every image in the dataset.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# train for 1 epoch
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch_stats&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_loader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nll_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reduction&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'none'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;grad_norms&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gradient_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_idx&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_loader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;batch_stats&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;g&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad_norms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;batch_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;g&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;epoch_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We can compute the correlation between &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;grad_norms&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;losses&lt;/code&gt; using the following one-liner:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;corr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;corrcoef&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad_norms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Pearson Correlation Coeff: {}&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;format&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;corr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# prints ~0.83
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
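
&lt;p&gt;For reference, the Pearson coefficient is the covariance of the two variables divided by the product of their standard deviations, with the same normalisation used for both. Here is a minimal standard-library sketch of that definition; the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pearson&lt;/code&gt; helper and the toy values are made up for illustration and are not part of the original notebook.&lt;/p&gt;

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sums of cross-deviations and squared deviations. The 1/n (or 1/(n-1))
    # factors cancel as long as the same one is used for every term.
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    syy = sum((y - mean_y) ** 2 for y in ys)
    return sxy / math.sqrt(sxx * syy)


# Toy values for illustration only (not taken from the experiment).
toy_grad_norms = [0.1, 0.4, 0.5, 0.9, 1.3]
toy_losses = [0.05, 0.30, 0.20, 0.70, 1.10]
toy_corr = pearson(toy_grad_norms, toy_losses)
```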

&lt;p&gt;The coefficient comes out to roughly 0.83, which indicates a strong positive linear relationship between the two variables. Next, we verify this graphically by sorting the samples by gradient norm, reindexing the losses in that order, and generating the aforementioned scatter plot.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# reindex the losses using the sorted gradient norms
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sublist&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch_stats&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sublist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sorted_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sorted&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sorted_losses&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sorted_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/pr-lr/loss_vs_grad.jpg&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;Sorted Losses According to Gradient Norm&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;The above plot suggests that we &lt;em&gt;can&lt;/em&gt; indeed use the loss value of a sample as a proxy for its importance. This is exciting news and opens up some interesting avenues for improving SGD.&lt;/p&gt;

&lt;p&gt;If you want to reproduce the above logic, click &lt;a href=&quot;https://github.com/kevinzakka/blog-code/blob/master/pr-lr/Loss%20vs%20Gradient%20Norm.ipynb&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc6&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;loss-patterns&quot;&gt;Loss Patterns&lt;/h2&gt;

&lt;p&gt;In this section, we’ll try to answer the following question:&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;Is a sample’s importance consistent across epochs? In other words, if a sample exhibits low loss in the early stages of training, is this still the case in later epochs?&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;There is substantial benefit in finding empirical evidence for this hypothesis. The reasons are twofold: &lt;strong&gt;first&lt;/strong&gt;, by eliminating consistently low-loss images from the dataset, we reduce training time in proportion to the number of discarded images; &lt;strong&gt;second&lt;/strong&gt;, by oversampling the high-loss images, we reduce the variance of the gradients and speed up convergence to &lt;script type=&quot;math/tex&quot;&gt;\hat{W}&lt;/script&gt;.&lt;/p&gt;

&lt;p&gt;To explore this idea, we’re going to track every sample’s loss over a set number of epochs. We’ll bin the loss values into 10 quantiles and compare the histograms over the different epochs. Finally, we’ll repeat these steps with shuffling turned off, then turned on.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;NB:&lt;/strong&gt; We need to be a bit careful about keeping track of a sample’s index when shuffling is turned on. The solution is to create a permutation of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[0, 1, 2, ..., 59999]&lt;/code&gt; at the beginning of every epoch and feed it to a sequential sampler &lt;strong&gt;with shuffling turned off&lt;/strong&gt;. At the end of training, we use the stored permutations to map each recorded position back to its true dataset index, which effectively simulates random shuffling while keeping every sample identifiable.&lt;/p&gt;

&lt;p&gt;If this sounds complicated, let me show you how simple it is to achieve in PyTorch:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# PermSampler takes a list of `indices` and iterates over it sequentially
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;PermSampler&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sampler&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__iter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__len__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# if `permutation` is None, we return a data loader with no shuffling
# if `permutation` is a list of indices, we return a data loader that iterates
# over the MNIST dataset with indices specified by `permutation`.
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_data_loader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data_dir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;permutation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;normalize&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Normalize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1307&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.3081&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Compose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ToTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;normalize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MNIST&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;root&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data_dir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;sampler&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;permutation&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sampler&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PermSampler&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permutation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loader&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;shuffle&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sampler&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sampler&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loader&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;After training for 5 epochs, we have, for each epoch, a list containing an &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(idx, loss_idx)&lt;/code&gt; pair for every image in the dataset, along with the permutation used in that epoch. We can remap the indices with the following code:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# remap the indices based on the permutations list
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;perm&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stats_with_shuffling_flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;permutations&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;perm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
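
&lt;p&gt;To see why this remapping recovers the true indices, here is a small standard-library sketch on a toy dataset of 8 samples. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;fake_loss&lt;/code&gt; function, the dataset size and the seed are assumptions made purely for illustration; in the real experiment the losses come from the model.&lt;/p&gt;

```python
import random

num_samples = 8          # toy dataset size (MNIST would be 60000)
num_epochs = 3
rng = random.Random(0)   # fixed seed so the sketch is repeatable


def fake_loss(true_idx):
    """Hypothetical per-sample loss that depends only on the true index."""
    return 0.1 * true_idx


permutations = []
stats = []               # stats[e][k] = [index, loss] for epoch e
for _ in range(num_epochs):
    perm = list(range(num_samples))
    rng.shuffle(perm)
    permutations.append(perm)
    # A sequential sampler visits perm[0], perm[1], ... in order. The training
    # loop only knows the visiting position k, so that is what gets recorded.
    stats.append([[k, fake_loss(perm[k])] for k in range(num_samples)])

# Remap each recorded position k back to the true dataset index perm[k].
for stat, perm in zip(stats, permutations):
    for k in range(len(stat)):
        stat[k][0] = perm[k]

# After remapping, every record pairs a true index with its own loss.
remapped_ok = all(
    loss == fake_loss(idx) for stat in stats for idx, loss in stat
)
```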

&lt;p&gt;Finally, we bin the sorted losses of every epoch into 10 bins and compute the percent match of bins across all epochs, the last 4 epochs, and the last 2 epochs.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;functools&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;reduce&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;percentage_split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;percentages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;cdf&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cumsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percentages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allclose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cdf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;stops&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cdf&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;bin_losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;all_epochs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ep&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;all_epochs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sorted&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ep&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ep&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reverse&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;splits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;percentage_split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;splits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# bin the stats of every epoch (assumed input: the remapped run with shuffling)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bin_losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stats_with_shuffling_flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# starting epoch of each window: all 5 epochs, the last 4, the last 2
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;all_matches&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;percent_matches&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_quantiles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# gather bin i from every epoch in the window, then intersect them
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;percentile_all&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;percentile_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percentile_splits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;matching&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;reduce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intersect1d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;percentile_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;percent&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matching&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percentile_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;percent_matches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percent&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;all_matches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percent_matches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;It’s interesting to compute percent matches across a varying range of epochs. The reason is that the training dynamics are less stable in the early epochs when the model weights are still random (analogous to transient response and steady state in circuit theory). For example, we would expect to have higher percent matches if we eliminate the first epoch from the analysis – and this is verified in the plots below!&lt;/p&gt;

&lt;div class=&quot;img&quot;&gt;
&lt;img src=&quot;/assets/pr-lr/no_shuffling.jpg&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
&lt;img src=&quot;/assets/pr-lr/with_shuffling.jpg&quot; width=&quot;100%&quot; style=&quot;border:none;&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;The histograms confirm our hypothesis:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;~ 30% of the samples with a loss value in the top 10% consistently rank in those ranges across all epochs. This number increases to ~ 60% across epochs 1 through 4 and ~ 85% across the last two epochs.&lt;/li&gt;
  &lt;li&gt;~ 30% of the samples with a loss value in the bottom 10% consistently rank in those ranges across all epochs. This number increases to ~ 50% across epochs 1 through 4 and ~ 70% across the last two epochs.&lt;/li&gt;
  &lt;li&gt;Shuffling has a minimal impact on the loss evolution of the samples across epochs.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;If you want to reproduce the histograms, click &lt;a href=&quot;https://github.com/kevinzakka/blog-code/blob/master/pr-lr/Loss%20Patterns.ipynb&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;
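
&lt;p&gt;For a rough idea of what the percent-match loop computes, here is a minimal, self-contained sketch. It assumes that each entry of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;percentile_splits&lt;/code&gt; holds the sample indices that land in a given percentile bucket at one epoch. The function name and the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;first_epoch&lt;/code&gt; argument are hypothetical, and it uses plain Python sets instead of the NumPy calls from the notebook.&lt;/p&gt;

```python
from functools import reduce


def percent_matches(percentile_splits, first_epoch=0):
    """Running percent of samples that stay in the bucket from first_epoch on.

    percentile_splits[j] is assumed to be the collection of sample indices
    that fall in the bucket (e.g. top 10% of losses) at epoch j.
    """
    running = []   # one set of indices per epoch seen so far
    matches = []   # percent match after each additional epoch
    for j in range(first_epoch, len(percentile_splits)):
        running.append(set(percentile_splits[j]))
        # indices present in the bucket at every epoch seen so far
        common = reduce(set.intersection, running)
        matches.append(100 * len(common) / len(running[0]))
    return matches


splits = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]]
print(percent_matches(splits))      # all epochs
print(percent_matches(splits, 1))   # skip the first epoch
```

&lt;p&gt;Skipping early epochs simply moves the starting point of the intersection, which is why dropping the unstable first epoch can only keep the percent match the same or raise it for the remaining epochs.&lt;/p&gt;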

&lt;p&gt;&lt;a name=&quot;toc7&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;sgd-on-steroids&quot;&gt;SGD on Steroids&lt;/h2&gt;

&lt;p&gt;&lt;a name=&quot;toc8&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Mini-Batch Resampling.&lt;/strong&gt; In the first version of SGD-S, we’re going to split our training epochs into 2 stages:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Transient Epochs&lt;/strong&gt;: in the transient epochs, we train our model exactly as we would in regular SGD. However, in the last epoch, we record and return the losses of every image in the dataset.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Steady-State Epochs&lt;/strong&gt;:
    &lt;ul&gt;
      &lt;li&gt;For every epoch in the steady-state, we sample batches using the loss as the sampling distribution.&lt;/li&gt;
      &lt;li&gt;At the end of every epoch in the steady-state, we eliminate the 10% of images with the lowest losses. Furthermore, we can choose to randomly reintroduce a fraction of the discarded images to combat potential catastrophic forgetting.&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Let’s illustrate how we can use the loss function to construct an importance sampling distribution for mini-batch resampling. This is achievable using PyTorch’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;WeightedRandomSampler&lt;/code&gt; in conjunction with the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DataLoader&lt;/code&gt;.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# sort the loss in decreasing order
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sorted&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reverse&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# house cleaning
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_remove&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perc_to_remove&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;to_keep&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perc_to_remove&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;to_add&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;01&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_loss_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;replace&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;new_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_keep&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_add&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;new_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sampler&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WeightedRandomSampler&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
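
&lt;p&gt;One detail worth keeping in mind: as far as I understand, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;WeightedRandomSampler&lt;/code&gt; yields positions into the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;weights&lt;/code&gt; list, so those positions still have to be mapped back through &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;new_idx&lt;/code&gt; to recover dataset indices. The sketch below illustrates the whole resampling step with the standard library only. It assumes &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;losses&lt;/code&gt; is a list of (index, loss) pairs; the function name, its arguments, and the fixed seed are hypothetical and are not part of the original code.&lt;/p&gt;

```python
import random


def resample_indices(losses, perc_to_remove=10, frac_to_readd=0.01, seed=0):
    """Draw dataset indices with probability proportional to their loss.

    losses is assumed to be a list of (index, loss) pairs, which mirrors
    the losses[k][1] access used in the snippet above.
    """
    rng = random.Random(seed)
    # rank sample positions from highest to lowest loss
    order = sorted(range(len(losses)), key=lambda k: losses[k][1], reverse=True)
    n_remove = int((perc_to_remove / 100) * len(order))
    cut = len(order) - n_remove
    to_keep, to_remove = order[:cut], order[cut:]
    # re-introduce a small random fraction of the discarded samples
    n_add = min(len(to_remove), int(frac_to_readd * len(order)))
    to_add = rng.sample(to_remove, n_add)
    new_idx = sorted(to_keep + to_add)
    weights = [losses[i][1] for i in new_idx]
    # sample with replacement, returning dataset indices rather than positions
    return rng.choices(new_idx, weights=weights, k=len(new_idx))


toy_losses = [(i, float(i + 1)) for i in range(100)]
drawn = resample_indices(toy_losses)
print(len(drawn), max(drawn))
```

&lt;p&gt;Splitting at an explicit cut point also sidesteps an edge case in negative slicing: when the number of samples to remove rounds down to zero, a slice that starts at minus zero selects the entire list instead of nothing.&lt;/p&gt;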

&lt;p&gt;&lt;a name=&quot;toc9&quot;&gt;&lt;/a&gt;
&lt;strong&gt;Auxiliary Model.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc10&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;things-i-wish-i-tried&quot;&gt;Things I Wish I Tried&lt;/h2&gt;

&lt;p&gt;&lt;a name=&quot;toc11&quot;&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2 id=&quot;closing-thoughts&quot;&gt;Closing Thoughts&lt;/h2&gt;

&lt;hr /&gt;
&lt;div class=&quot;footnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot;&gt;
      &lt;p&gt;CIFAR results pending. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot;&gt;
      &lt;p&gt;Explain how. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot;&gt;
      &lt;p&gt;Add proof or point to it. &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;
</description>
        <pubDate>Fri, 28 Sep 2018 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2018/09/28/prioritized-learning/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2018/09/28/prioritized-learning/</guid>
        
        <category>deep learning</category>
        
        <category>sgd</category>
        
        <category>importance sampling</category>
        
        <category>pytorch</category>
        
        <category>2018</category>
        
        
      </item>
    
      <item>
        <title>Getting Up and Running with PyTorch on Amazon Cloud</title>
        <description>&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/splash.png&quot; alt=&quot;Drawing&quot; width=&quot;60%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;This is a succinct tutorial aimed at helping you set up an AWS GPU instance so that you can train and test your PyTorch models in the cloud. If, like me, you don’t own a GPU, this can be a great way of drastically reducing the training time of your models, so while your instance is furiously crunching numbers in some faraway Amazon server, you can peacefully experiment with and prototype new architectures from the comfort of a Starbucks couch.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/cpu-meter.png&quot; width=&quot;30%&quot; style=&quot;border:none;&quot; /&gt;
 &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;I mean we all love a silent macbook, right?&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;The cool part is that if you’re a high school or college student, you can sign up for the GitHub Student Developer Pack which will get you $150 worth of free AWS credits. That’s around 167 hours or 7 days of compute time&lt;sup id=&quot;fnref:1&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;, an amply sufficient amount for those fun weekend side projects and experiments. As usual, any code or script that appears on this page can be downloaded from my &lt;a href=&quot;https://github.com/kevinzakka/blog-code/tree/master/aws-pytorch&quot;&gt;Blog Repository&lt;/a&gt;. And on that note, let’s get started!&lt;/p&gt;
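
&lt;p&gt;If you want to redo the compute-time estimate with your own numbers, the arithmetic is a one-liner. The sketch below uses the $150 of credits and the roughly $0.9/hr on-demand price quoted in this post; both are assumptions that you should replace with current values, and the variable names are just for illustration.&lt;/p&gt;

```python
credits_usd = 150.0        # free credits quoted in this post
price_per_hour_usd = 0.9   # approximate p2.xlarge on-demand price quoted in this post

hours = credits_usd / price_per_hour_usd   # how long the credits last
days = hours / 24

print(f"{hours:.0f} hours, about {days:.1f} days")
```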

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Configuring Your EC2 Instance&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;Launching &amp;amp; Managing Your EC2 Instance&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;SSH Persistence With TMUX&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;Conclusion&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;configuring-your-ec2-instance&quot;&gt;Configuring Your EC2 Instance&lt;/h2&gt;

&lt;p&gt;I’m assuming you’ve already created an AWS account but if you haven’t, the whole process shouldn’t take you more than 2 minutes. Note that it will require you to enter your credit card information which is necessary to charge you &lt;em&gt;if and when&lt;/em&gt; you exceed your free credits. Now’s also a great time to claim your &lt;a href=&quot;https://education.github.com/pack&quot;&gt;GitHub Student Developer Pack&lt;/a&gt; credits so go ahead and do that.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Pick your Region.&lt;/strong&gt; Ok, so the instance type we are going to use is located in &lt;strong&gt;US West (Oregon)&lt;/strong&gt; so make sure the region information on the top right of the screen correctly reflects that.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step1.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Limit Increase.&lt;/strong&gt; The next thing we need to do is request a limit increase for EC2 instances. For some weird reason, Amazon automatically sets the limit to 0 upon account creation so it has to be increased by sending in a support ticket.&lt;/p&gt;

&lt;p&gt;Go ahead and click &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Support &amp;gt; Support Center&lt;/code&gt; at the top right of your screen. This will direct you to a page with a blue &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Create Case&lt;/code&gt; button that you should click. You’ll be greeted with the following:&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step2.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;We want a Limit Increase for EC2 instances meaning you need to select &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Service Limit Increase&lt;/code&gt; in &lt;strong&gt;Regarding&lt;/strong&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EC2 Instances&lt;/code&gt; in &lt;strong&gt;Limit Type&lt;/strong&gt;. Now fill in the &lt;strong&gt;Request 1&lt;/strong&gt; box and &lt;strong&gt;Use Case Description&lt;/strong&gt; as I’ve done here.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step3.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Finally, make sure to select &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Web&lt;/code&gt; as your &lt;strong&gt;Contact method&lt;/strong&gt; and submit the request. Note that the time of response varies: I’ve had limit increases resolved in a matter of minutes and sometimes up to a full day, so be patient. Also, feel free to change the &lt;strong&gt;New limit value&lt;/strong&gt; to suit your needs. I’ve opted for 2 because the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;p2.xlarge&lt;/code&gt; instance type we’ll be working with has a single GPU with memory constraints that may limit the number of jobs I may run concurrently.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Configure Instance.&lt;/strong&gt; Ok, we’re now ready to create and configure our EC2 instance. Back on the home page console (click on the orange cube in the top left), navigate to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EC2&lt;/code&gt; in the Compute services section, and then click on the blue &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Launch Instance&lt;/code&gt; button.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step4.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;You’ll be greeted with a 7-step process like so.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step5.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;AMI.&lt;/strong&gt; First select the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Ubuntu Server 16.04 LTS (HVM), SSD Volume Type&lt;/code&gt; as the AMI of choice.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Instance.&lt;/strong&gt; Select &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;p2.xlarge&lt;/code&gt; as your instance type. This is an instance with a single GPU which is what we asked for in our limit increase request.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Spot Instances.&lt;/strong&gt; At this point, you should be on the &lt;strong&gt;Configure Instance Details&lt;/strong&gt; step. This is where things get interesting. In fact, Amazon gives us the ability to bid on spare Amazon EC2 computing capacity for a much cheaper price than the on-demand one.&lt;/p&gt;

&lt;p&gt;Basically, what that means is that if our bid price is higher than the current market price, our instance will be launched and charged at that price. The only downside is that if that ever flips around, instances get &lt;span style=&quot;color:red&quot;&gt;terminated&lt;/span&gt; instantly and with no warning&lt;sup id=&quot;fnref:2&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;

&lt;p&gt;&lt;span style=&quot;color:blue&quot;&gt;TL;DR:&lt;/span&gt; Spot instances can be ideal for non-critical experimentation like hyperparameter tuning but stay away from them if you need to train a model for a large number of epochs.&lt;/p&gt;

&lt;p&gt;I’ll assume you’re using On-Demand pricing for the remainder of this post, but if you do want to find out more about Spot Instances, feel free to watch this YouTube &lt;a href=&quot;https://www.youtube.com/watch?v=_XT6McviY7w&quot;&gt;video&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Add Storage.&lt;/strong&gt; Next, we’ll be increasing the size of our Root Volume to accommodate large datasets such as ImageNet which is around 48 GB. Feel free to enter any number above that.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step6.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Note that the Root Volume is EBS-backed meaning it persists on instance termination. The default behavior however is to delete it on termination. Weird right? Well, not really. With ephemeral storage, the other type of storage AWS offers, there is no persist option, whether it be on instance stop or terminate. Thus EBS with delete-on-terminate gives us the ability to keep our data on disk when the instance is stopped!&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Configure Security Group.&lt;/strong&gt; You can skip the &lt;strong&gt;Add Tags&lt;/strong&gt; section and jump to this last step. This part is important because it will allow us to monitor our training with Tensorboard and use Jupyter Notebook. We’ll be adding 4 protocols as shown in the picture below.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step7.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Once you click the launch button, a window will pop up and prompt you to create a key-pair. This little file is needed when ssh-ing into your instance, so download it and store it in a secure location you’ll remember. For this tutorial’s sake, I’ll be calling mine &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;aws-dl.pem&lt;/code&gt; and storing it in my Downloads folder.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step8.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;launching--managing-your-ec2-instance&quot;&gt;Launching &amp;amp; Managing Your EC2 Instance&lt;/h2&gt;

&lt;p&gt;We’ve finally arrived at the point where we can ssh into our EC2 instance. To do so, you’ll need to navigate to the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Instances&lt;/code&gt; page located in the navigation panel on the left of your screen. You’ll be greeted with the following:&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step9.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;You need to take note of 2 things:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Public DNS (IPv4)&lt;/strong&gt;: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ec2-52-42-90-161.us-west-2.compute.amazonaws.com&lt;/code&gt;&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;IPv4 Public IP&lt;/strong&gt;: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;52.42.90.161&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Other than that, there are just 2 ways to interact with your instance you need to be aware of: &lt;strong&gt;login&lt;/strong&gt; with ssh and &lt;strong&gt;copy&lt;/strong&gt; a file to it with scp.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ssh -v -i X ubuntu@Y&lt;/code&gt; where X represents the path to the key-pair file and Y represents the Public IP of your instance.&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scp -i W -r X ubuntu@Y:Z&lt;/code&gt; where W is the path to the key-pair file, X is the path to the local file, Y is the Public IP, and Z is the destination path on the instance.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;It’s important to note that if you’re using the key-pair file for the very first time, you’ll need to restrict its permissions to owner-only read and write by running &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;chmod 600 ~/Downloads/aws-dl.pem&lt;/code&gt;, since ssh refuses to use a private key that other users can read.&lt;/p&gt;
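
&lt;p&gt;If you’d rather script these steps, here is a small Python sketch that builds the same ssh and scp invocations as argument lists and applies the equivalent of chmod 600 to the key. The helper names are hypothetical, and the sketch only constructs the commands; it does not run them.&lt;/p&gt;

```python
import os
import stat


def restrict_key(key_path):
    """Equivalent of chmod 600: the owner may read and write, nobody else."""
    os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR)


def ssh_command(key_path, host, user="ubuntu"):
    """Argument list for: ssh -v -i KEY user@host"""
    return ["ssh", "-v", "-i", key_path, f"{user}@{host}"]


def scp_command(key_path, local_path, host, remote_path="~/.", user="ubuntu"):
    """Argument list for: scp -i KEY -r LOCAL user@host:REMOTE"""
    return ["scp", "-i", key_path, "-r", local_path, f"{user}@{host}:{remote_path}"]


print(" ".join(ssh_command("~/Downloads/aws-dl.pem", "52.42.90.161")))
print(" ".join(scp_command("~/Downloads/aws-dl.pem", "~/Desktop/install.sh", "52.42.90.161")))
```

&lt;p&gt;The lists can be handed to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;subprocess.run&lt;/code&gt;, but note that no shell is involved in that case, so local paths starting with a tilde should first go through &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;os.path.expanduser&lt;/code&gt;.&lt;/p&gt;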

&lt;p&gt;With all that being said, we can finally fire up a terminal and execute the following command:&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;
ssh -v -i ~/Downloads/aws-dl.pem ubuntu@52.42.90.161
&lt;/code&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/term1.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Enter yes, and voila! You should be successfully logged in. The instance is still not ready for use as there are a few more things that need to be done, but fear not. I’ve created a small bash script that you can execute which automates the following:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;It downloads and installs the required nvidia gpu drivers.&lt;/li&gt;
  &lt;li&gt;It updates and upgrades the distribution packages.&lt;/li&gt;
  &lt;li&gt;It installs python3 along with virtualenv.&lt;/li&gt;
  &lt;li&gt;It creates a virtualenv called &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;deepL&lt;/code&gt; that will house all the required pip packages and PyTorch.&lt;/li&gt;
  &lt;li&gt;Finally, it installs PyTorch v0.2.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Go ahead and download &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;install.sh&lt;/code&gt; from my &lt;a href=&quot;https://github.com/kevinzakka/blog-code/tree/master/aws-pytorch&quot;&gt;repo&lt;/a&gt; and save it to your Desktop. We need to copy it to our instance, so apply the command I mentioned above:&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;
scp -i ~/Downloads/aws-dl.pem -r ~/Desktop/install.sh ubuntu@52.42.90.161:~/.
&lt;/code&gt;&lt;/p&gt;

&lt;p&gt;Next, go back to the terminal window logged into the instance and execute the following 2 commands:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;chmod +x install.sh
./install.sh
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Once that’s done, you’ll need to reboot your instance. Enter &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;exit&lt;/code&gt; at the command line and navigate to your browser as in the image below. Be patient and wait for a few minutes before you ssh back into the instance!&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/step10.png&quot; alt=&quot;Drawing&quot; width=&quot;80%&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;At this point, we should sanity check our installation by seeing if PyTorch loads correctly.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;First, activate the virtualenv by executing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;source ~/envs/deepL/bin/activate&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Enter &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;python&lt;/code&gt; and inside the interpreter, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;import torch&lt;/code&gt; then &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.__version__&lt;/code&gt;. Fingers crossed, this should print out &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.2.0_1&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Lastly, check that the GPU is visible by typing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.cuda.is_available()&lt;/code&gt; which should print out True.&lt;/li&gt;
&lt;/ul&gt;
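
&lt;p&gt;The same sanity check can be wrapped in a tiny script, which is handy if you rebuild instances often. This is only a sketch: the function name is hypothetical, and it returns None instead of raising when PyTorch is missing from the active environment.&lt;/p&gt;

```python
import importlib.util


def check_pytorch():
    """Return (version, cuda_available), or None if torch is not installed."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch  # imported lazily so the script still runs without PyTorch
    return torch.__version__, torch.cuda.is_available()


print(check_pytorch())
```

&lt;p&gt;Run it with the virtualenv activated. On a correctly configured GPU instance, the second element of the tuple should be True.&lt;/p&gt;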

&lt;p&gt;&lt;span style=&quot;color:red&quot;&gt;Once you’ve finished working on your instance, you should stop it immediately to avoid incurring additional charges.&lt;/span&gt;&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;ssh-persistence-with-tmux&quot;&gt;SSH Persistence With TMUX&lt;/h2&gt;

&lt;p&gt;I would be doing you a great disservice if I didn’t mention this nifty little package called &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tmux&lt;/code&gt; that you can use when running your instances for long periods of time. &lt;em&gt;What exactly is tmux, and why should you use it&lt;/em&gt;?&lt;/p&gt;

&lt;p&gt;Well, if you’re ssh’ed into an instance, peacefully running a job, and your connection suddenly drops, your ssh session will automatically get killed. This means anything running in that session stops as well (i.e. your model will stop training). Closing your laptop to commute from university to your house, for example, becomes a big no-no.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/aws/term3.png&quot; width=&quot;80%&quot; style=&quot;border:none;&quot; /&gt;
 &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;A TMUX session&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;This is where tmux comes in! Tmux makes it so that anything running within a session persists even if the connection drops or the terminal gets killed. To see it in action, I’d suggest you watch the following &lt;a href=&quot;https://www.youtube.com/watch?v=BHhA_ZKjyxo&quot;&gt;video&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Thus, your workflow should always be as follows:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;SSH into your AWS instance.&lt;/li&gt;
  &lt;li&gt;Create a new tmux session called work using the command &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tmux new -s work&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Do everything as you would previously.&lt;/li&gt;
  &lt;li&gt;Detach from the session by pressing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ctrl-b&lt;/code&gt; followed by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;d&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Once you’ve detached yourself from the session, you can work on anything else, even go to sleep… Subsequently, if you need to reattach to that particular tmux session to check your progress, run &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tmux a -t work&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;That’s pretty much it. For a more complete list of tmux commands, you should refer to this lovely &lt;a href=&quot;https://gist.github.com/MohamedAlaa/2961058&quot;&gt;cheatsheet&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;In this tutorial, we went over the basic steps needed to create a free, GPU-powered Amazon AWS instance. We explored how to interact with our instance using the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ssh&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scp&lt;/code&gt; commands and how a bash script could be leveraged to download and install all the required packages needed to run PyTorch. Finally, we saw how we could make our ssh session persistent using a very important program called tmux.&lt;/p&gt;

&lt;p&gt;Until next time!&lt;/p&gt;

&lt;hr /&gt;
&lt;div class=&quot;footnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot;&gt;
      &lt;p&gt;This is for a GPU-powered p2.xlarge instance with an on-demand price of around $0.9/hr. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot;&gt;
      &lt;p&gt;A terminated instance gets deleted, meaning you lose whatever’s on there permanently. On the other hand, a stopped instance just goes offline so you don’t get charged for it and you can fire it back up again at a later time. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;
</description>
        <pubDate>Sun, 13 Aug 2017 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2017/08/13/aws-pytorch/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2017/08/13/aws-pytorch/</guid>
        
        <category>deep learning</category>
        
        <category>aws</category>
        
        <category>amazon</category>
        
        <category>pytorch</category>
        
        <category>2017</category>
        
        
      </item>
    
      <item>
        <title>Understanding Recurrent Neural Networks - Part I</title>
        <description>&lt;p&gt;Recurrent Neural Networks have been my Achilles’ heel for the past few months. Admittedly, I haven’t had the grit to sit down and work out their details, but I’ve figured it’s time I stop treating them like black boxes and try instead to discover what makes them tick. My intentions with this series are hence twofold: first, to combat my weakness by understanding their inner workings and coding one from scratch; and second, to write down what I learn in order to reinforce the insights I may gain along the way.&lt;/p&gt;

&lt;p&gt;In this first installment, we’ll be introducing the intuition behind RNNs, motivating their use by highlighting a glaring limitation of traditional neural networks. We’ll then transition into a more technical description of their architecture which will be useful for the next installment where we’ll code one from scratch in numpy.&lt;/p&gt;

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Human Learning&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;The Woes of Traditional Neural Nets&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;Enhancing Neural Networks with Memory&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;The Nitty Gritty Details&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc5&quot;&gt;References&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;human-learning&quot;&gt;Human Learning&lt;/h3&gt;

&lt;blockquote&gt;
&lt;p&gt;We are the sum total of our experiences. None of us are the same as we were yesterday, nor will be tomorrow.&lt;/p&gt;
&lt;cite&gt;B.J. Neblett&lt;/cite&gt;
&lt;/blockquote&gt;

&lt;p&gt;There is an inherent truth to the quote above. Our brain pools from past experiences and combines them in intricate ways to solve new and unseen tasks. It is hardwired to work with sequences of information that we perpetually store and call upon over the course of our lives. At its core, &lt;em&gt;human learning&lt;/em&gt; can be distilled into two fundamental processes:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;memorization&lt;/strong&gt;: every time we gain new information, we store it for future reference.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;combination&lt;/strong&gt;: not all tasks are the same, so we couple our analytical skills with a combination of our memorized, previous experiences to reason about the world.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Consider the following pictures.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/rnn/weird_cat.jpg&quot; alt=&quot;Drawing&quot; width=&quot;200px&quot; /&gt;&lt;img src=&quot;/assets/rnn/weird_cat2.jpg&quot; alt=&quot;Drawing&quot; width=&quot;200px&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Even though it’s in a very weird position, a child can instantly tell that the fur ball in front of it is a cat. It’ll recognize the ears, the whiskers and the snout (memory) but the shape of it all may throw it off. Subconsciously however, the child may recall how human stretching deforms shape and pose (combination), and infer that the same is happening to the cat.&lt;/p&gt;

&lt;p&gt;Not all tasks require the distant past however. At times, solving a problem makes use of information that was processed only moments ago. For example, take a look at this incomplete sentence:&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;I bought my usual caramel-covered popcorn with iced tea and headed to the ___.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;If I asked you to fill in the missing word, you’d probably guess “movies”. How did you know that &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;library&lt;/code&gt; or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;starbucks&lt;/code&gt; were invalid words? Well, it’s probably because you used context, or information from earlier in the sentence, to infer the correct answer. Now think about the following. If I asked you to recite the lyrics of your favorite song backwards, would you be able to do it? Probably not… What about counting backwards? Yeah, piece of cake!&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/rnn/yarn.jpg&quot; alt=&quot;Drawing&quot; width=&quot;200px&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;So what makes reciting the song backwards so excruciatingly difficult? The answer is that counting backwards is done &lt;strong&gt;on the fly&lt;/strong&gt;. There is a logical relationship between each number, and knowing the order of the 10 digits and how subtraction works means you can count backwards from, say, 1845098 even if you’ve never done it before. On the other hand, you memorized the lyrics of the song in a specific order. Your brain works by &lt;strong&gt;indexing&lt;/strong&gt; from one word to the next, starting from the first word. It’s hard to index backwards for the simple reason that your brain has never done it before, so that specific sequence was never stored. Think of the memorized lyric sequence as a giant ball of yarn whose unraveled end can only be accessed with the correct first word in the forward sequence.&lt;/p&gt;

&lt;p&gt;The main takeaway is that our brains are naturally talented at working with sequences and they do so by relying on a deceptively simple, yet powerful concept called &lt;strong&gt;information persistence&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;the-woes-of-traditional-neural-nets&quot;&gt;The Woes of Traditional Neural Nets&lt;/h3&gt;

&lt;p&gt;We live in a world that is inherently sequential. Audio, video, and language (even your DNA!) are but a few examples of data in which information at a given timestep is intricately dependent on information from previous timesteps. So how is all this related to deep learning? Well, think about feeding a sequence of frames from a video into a neural network and asking it to predict what comes next… Or, back to our previous example, feeding a set of words and asking it to complete the sentence.&lt;/p&gt;

&lt;p&gt;It should be obvious to you that information from the past is crucial for outputting a sane and plausible prediction. But traditional neural networks can’t do this because they operate on the fundamental assumption that inputs are independent! This is a problem because it means our output at any given time is completely and &lt;strong&gt;solely&lt;/strong&gt; determined by the input at that same time. There is no previous history and our network cannot capitalize on the complex temporal dependencies that exist between the different frames or words to refine its predictions.&lt;/p&gt;

&lt;p&gt;This is where &lt;em&gt;Recurrent Neural Networks&lt;/em&gt; come in! RNNs allow us to deal with sequences by incorporating a mechanism that stores and leverages information from previous history, sort of like a memory. Put differently, whereas a traditional net maps &lt;strong&gt;one&lt;/strong&gt; input to an output, a recurrent net maps an &lt;strong&gt;entire history&lt;/strong&gt; of previous inputs to each output. If that’s still obscure to you, just think of RNNs as a traditional neural net enhanced with a loop&lt;sup id=&quot;fnref:1&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;, one that allows for information to persist across timesteps.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/rnn/draw2.gif&quot; width=&quot;400&quot; style=&quot;border:none;&quot; /&gt;
 &lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;(&lt;a href=&quot;https://www.youtube.com/watch?v=Zt-7MI9eKEo&quot;&gt;Video Courtesy&lt;/a&gt;) DRAW model improving its output by iterating over the canvas rather than producing the image in one shot.&lt;/div&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;p&gt;It is important to note that recurrent neural nets aren’t just bound to sequential data in the sense that many problems can be tackled by decomposing them into a series of smaller subproblems. The idea is that instead of burdening our model with predicting an output in one go, we allow it the much easier task of predicting iterative sub-outputs, where each sub-output is an improvement or refinement on the previous step. As an example, a recurrent net&lt;sup id=&quot;fnref:2&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt; was used to generate handwritten digits in a sequential fashion, mimicking the way artists refine and reassess their work with brushstrokes.&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;The idea is that instead of burdening our model with predicting an output in one go, we allow it the much easier task of predicting iterative sub-outputs, where each sub-output is an improvement or refinement on the previous step.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;enhancing-neural-nets-with-memory&quot;&gt;Enhancing Neural Nets with Memory&lt;/h3&gt;

&lt;p&gt;So how exactly can we endow our networks with the ability to memorize? To answer this question, let’s recall our basic hidden layer neural network, which takes as input a vector &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt;, dot products it with a weight matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;W&lt;/code&gt; and applies a nonlinearity. We’ll consider the output &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt; when three successive inputs are fed through the network. Note that the bias term has been eliminated so as to simplify the notation, and I’ve taken the liberty of coloring the equations to make certain patterns stand out.&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_0 = f(W_x\color{blue}{X_0})&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_1 = f(W_x \color{green}{X_1})&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_2 = f(W_x \color{red}{X_2})&lt;/script&gt;

&lt;p&gt;Given the simple API above, it’s pretty clear that each output is solely determined by its input, i.e. there is no trace of past inputs in the calculation of its value. So let’s alter the API by allowing our hidden layer to use a combination of both the current input and the previous input, and visualize what happens.&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_0 = f(W_x\color{blue}{X_0})&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_1 = f(W_x \color{green}{X_1} + W_h\color{blue}{X_0})&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_2 = f(W_x \color{red}{X_2} + W_h\color{green}{X_1})&lt;/script&gt;

&lt;p&gt;Nice! By introducing recurrence into the formula, we’ve managed to obtain a mix of 2 colors in each hidden layer. Intuitively, our network now has a memory depth of 1, equivalent to “seeing” one step backwards in time. Remember though that our goal is to be able to capture information across &lt;strong&gt;all&lt;/strong&gt; previous timesteps, so this does not cut it.&lt;/p&gt;

&lt;p&gt;Hmm… What if we feed in a combination of the current input and the previous hidden layer?&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_0 = f(W_x\color{blue}{X_0})&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_1 = f\big(W_x \color{green}{X_1} + W_h \ f(W_x\color{blue}{X_0}) \big)&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;y_2 = f\bigg(W_x \color{red}{X_2} + W_h \ f\big(W_x \color{green}{X_1} + W_h \ f(W_x\color{blue}{X_0}) \big)\bigg)&lt;/script&gt;

&lt;p&gt;Much better! Our layer at each timestep is now a blend of all the colors that have come before it, allowing our network to take into account all its past history when computing its output. This is the power of recurrence in all its glory: creating a loop where information can persist across timesteps.&lt;/p&gt;
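
&lt;p&gt;To make the color-mixing argument concrete, here’s a minimal NumPy sketch of the three schemes above. The layer sizes, the random weights and the choice of tanh for f are my own assumptions for illustration, and since the second scheme multiplies a past input while the third multiplies a past hidden layer, the sketch uses two separate (hypothetical) matrices for them. It nudges the first input and checks whether the third output notices.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hid = 3, 4  # arbitrary sizes, chosen only for this illustration

W_x = 0.5 * rng.normal(size=(n_hid, n_in))        # input-to-hidden weights
W_h_input = 0.5 * rng.normal(size=(n_hid, n_in))  # multiplies the previous input
W_h = 0.5 * rng.normal(size=(n_hid, n_hid))       # multiplies the previous hidden layer
f = np.tanh  # assumed nonlinearity


def feedforward(xs):
    # y_t = f(W_x X_t): each output only sees its own input
    return [f(W_x @ x) for x in xs]


def previous_input(xs):
    # y_t = f(W_x X_t + W_h X_{t-1}): memory depth of 1
    ys = [f(W_x @ xs[0])]
    for t in range(1, len(xs)):
        ys.append(f(W_x @ xs[t] + W_h_input @ xs[t - 1]))
    return ys


def recurrent(xs):
    # y_t = f(W_x X_t + W_h y_{t-1}): the previous hidden layer is fed back in
    ys = [f(W_x @ xs[0])]
    for t in range(1, len(xs)):
        ys.append(f(W_x @ xs[t] + W_h @ ys[-1]))
    return ys


xs = [rng.normal(size=n_in) for _ in range(3)]
xs_perturbed = [xs[0] + 1.0, xs[1], xs[2]]  # only X_0 is changed

for name, net in [("feedforward", feedforward),
                  ("previous input", previous_input),
                  ("recurrent", recurrent)]:
    changed = not np.allclose(net(xs)[2], net(xs_perturbed)[2])
    print(name, "- y_2 depends on X_0:", changed)
```

&lt;p&gt;Only the recurrent version should report a dependence on the first input, which is exactly the “blend of all the colors” we were after.&lt;/p&gt;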

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;the-nitty-gritty-details&quot;&gt;The Nitty Gritty Details&lt;/h3&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/rnn/rnn-1_layer-unrolled.svg&quot; width=&quot;300px&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;http://kbullaughey.github.io/lstm-play/rnn/&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;At its core, an RNN can be represented by an internal, hidden state &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h&lt;/code&gt; that gets updated with every timestep and from which an output &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt; can be optionally derived&lt;sup id=&quot;fnref:3&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt;. This update behavior is governed by the following equations:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\begin{cases}
    h_t = f \big(W_{xh}x_t + W_{hh}h_{t-1}+b_1\big) \\
    y_t = g \big(W_{hy}h_t + b_2\big)
\end{cases}&lt;/script&gt;

&lt;p&gt;Don’t let the above notation scare you. It’s actually very simple once you dissect it.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;W_{xh}x_t&lt;/script&gt; -  we’re multiplying the input &lt;script type=&quot;math/tex&quot;&gt;x_t&lt;/script&gt; by a weight matrix &lt;script type=&quot;math/tex&quot;&gt;W_{xh}&lt;/script&gt;. You can think of this dot product as a way for the hidden layer to extract information out of the input.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;W_{hh}h_{t-1}&lt;/script&gt; - this dot product is allowing the network to extract information from an entire history of past inputs which it will use in conjunction with information gathered from the current input, to compute its output. This is the crucial, self-defining property of RNNs.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;f&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;g&lt;/script&gt; are activation functions that squash the dot products to a specific range. The function &lt;script type=&quot;math/tex&quot;&gt;f&lt;/script&gt; is usually &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tanh&lt;/code&gt; or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ReLU&lt;/code&gt;. &lt;script type=&quot;math/tex&quot;&gt;g&lt;/script&gt; can be a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;softmax&lt;/code&gt; when we want to output class probabilities.&lt;/li&gt;
  &lt;li&gt;&lt;script type=&quot;math/tex&quot;&gt;b_1&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;b_2&lt;/script&gt; are biases that help offset the outputs away from the origin (similar to the b in your typical &lt;script type=&quot;math/tex&quot;&gt;ax+b&lt;/script&gt; line).&lt;/li&gt;
&lt;/ul&gt;
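
&lt;p&gt;As a rough preview, the two update equations above can be sketched in a few lines of NumPy. This is only a sketch under assumptions: the sizes are made up, f is taken to be tanh and g a softmax, the initial hidden state is set to zero, and the helper names &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;rnn_step&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;softmax&lt;/code&gt; are hypothetical.&lt;/p&gt;

```python
import numpy as np


def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    e = np.exp(z - np.max(z))
    return e / e.sum()


def rnn_step(x_t, h_prev, W_xh, W_hh, W_hy, b_1, b_2):
    # h_t = f(W_xh x_t + W_hh h_{t-1} + b_1), with f = tanh (assumed)
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_1)
    # y_t = g(W_hy h_t + b_2), with g = softmax (assumed)
    y_t = softmax(W_hy @ h_t + b_2)
    return h_t, y_t


rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 5, 2  # hypothetical sizes

W_xh = 0.1 * rng.normal(size=(hidden_size, input_size))
W_hh = 0.1 * rng.normal(size=(hidden_size, hidden_size))
W_hy = 0.1 * rng.normal(size=(output_size, hidden_size))
b_1 = np.zeros(hidden_size)
b_2 = np.zeros(output_size)

h = np.zeros(hidden_size)  # initial hidden state, assumed to be all zeros
for x_t in rng.normal(size=(4, input_size)):  # a toy sequence of 4 timesteps
    h, y = rnn_step(x_t, h, W_xh, W_hh, W_hy, b_1, b_2)
    print(y)  # class probabilities at this timestep
```

&lt;p&gt;Note how the same weight matrices are reused at every timestep; the only thing that changes from one step to the next is the hidden state.&lt;/p&gt;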

&lt;p&gt;As you can see, the Vanilla RNN model is quite simple. Once its architecture has been defined, training it is exactly the same as with normal neural nets, i.e. initializing the weight matrices and biases, defining a loss function and minimizing that loss function using some form of gradient descent.&lt;/p&gt;

&lt;p&gt;This concludes our first installment in the series. In next week’s blog post, we’ll be coding our very own RNN from the ground up in numpy and applying it to a language modeling task. Stay tuned until then…&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc5&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;references&quot;&gt;References&lt;/h3&gt;

&lt;p&gt;There are a ton of resources that helped me better grasp the fundamentals of RNNs. I’d like to thank &lt;a href=&quot;https://twitter.com/iamtrask&quot;&gt;iamtrask&lt;/a&gt; especially, for letting me use his idea of colors to explain neural memory. You can read his amazing blog post &lt;a href=&quot;https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Denny Britz’s RNN series - click &lt;a href=&quot;http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Andrej Karpathy’s Blog Post - click &lt;a href=&quot;http://karpathy.github.io/2015/05/21/rnn-effectiveness/&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Chris Olah’s Blog Post - click &lt;a href=&quot;http://colah.github.io/posts/2015-08-Understanding-LSTMs/&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;hr /&gt;
&lt;div class=&quot;footnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot;&gt;
      &lt;p&gt;If you’re familiar with Control Theory, this should be slightly reminiscent of a feedback loop, although not quite. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot;&gt;
      &lt;p&gt;I’m referring to the &lt;a href=&quot;https://arxiv.org/abs/1502.04623&quot;&gt;DRAW&lt;/a&gt; model introduced by Gregor et al. at DeepMind. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot;&gt;
      &lt;p&gt;In the simplest of cases, the hidden state &lt;script type=&quot;math/tex&quot;&gt;h_t&lt;/script&gt; is used as both the output &lt;script type=&quot;math/tex&quot;&gt;y_t&lt;/script&gt; and input to the next hidden state &lt;script type=&quot;math/tex&quot;&gt;h_{t+1}&lt;/script&gt;. &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;
</description>
        <pubDate>Thu, 20 Jul 2017 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2017/07/20/rnn/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2017/07/20/rnn/</guid>
        
        <category>deep learning</category>
        
        <category>rnn</category>
        
        <category>sequences</category>
        
        <category>2017</category>
        
        
      </item>
    
      <item>
        <title>Deep Learning Paper Implementations: Spatial Transformer Networks - Part II</title>
        <description>&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn2/ai.jpg&quot; width=&quot;45%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://www.technologyreview.com/s/601519/how-to-create-a-malevolent-artificial-intelligence/&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;In last week’s &lt;a href=&quot;https://kevinzakka.github.io/2017/01/10/stn-part1/&quot;&gt;blog post&lt;/a&gt;, we introduced two very important concepts: &lt;strong&gt;affine transformations&lt;/strong&gt; and &lt;strong&gt;bilinear interpolation&lt;/strong&gt; and mentioned that they would prove crucial in understanding Spatial Transformer Networks.&lt;/p&gt;

&lt;p&gt;Today, we’ll provide a detailed, section-by-section summary of the &lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;Spatial Transformer Networks&lt;/a&gt; paper, a concept originally introduced by researchers &lt;em&gt;Max Jaderberg, Karen Simonyan, Andrew Zisserman and Koray Kavukcuoglu&lt;/em&gt; of Google Deepmind.&lt;/p&gt;

&lt;p&gt;Hopefully, it’ll give you a clear understanding of the module and prove useful for next week’s blog post where we’ll cover its implementation in Tensorflow.&lt;/p&gt;

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Motivation&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;Pooling Operator&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;Spatial Transformer Network&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;Localisation Network&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc5&quot;&gt;Parametrised Sampling Grid&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc6&quot;&gt;Differentiable Image Sampling&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc7&quot;&gt;Fun with STNs&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc8&quot;&gt;Distorted MNIST&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc9&quot;&gt;GTSRB dataset&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc10&quot;&gt;Summary&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc11&quot;&gt;References&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;motivation&quot;&gt;Motivation&lt;/h2&gt;

&lt;p&gt;When working on a classification task, it is usually desirable that our system be &lt;strong&gt;robust&lt;/strong&gt; to input variations. By this, we mean to say that should an input undergo a certain “transformation” so to speak, our classification model should in theory spit out the same class label as before that transformation. A few examples of the “challenges” our image classification model may face include:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;scale variation&lt;/strong&gt;: variations in size both in the real world and in the image.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;viewpoint variation&lt;/strong&gt;: different object orientation with respect to the viewer.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;deformation&lt;/strong&gt;: non rigid bodies can be deformed and twisted in unusual shapes.&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;div&gt;
&lt;img src=&quot;/assets/stn2/var1.png&quot; style=&quot;max-width:49%; height:350px;&quot; /&gt;
&lt;img src=&quot;/assets/stn2/var2.png&quot; style=&quot;max-width:49%; height:200px;&quot; /&gt;
&lt;/div&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;http://cs231n.github.io/classification/&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;For illustration purposes, take a look at the images above. While the task of classifying them may seem trivial to a human being, recall that our computer algorithms only work with raw 3D arrays of brightness values so a tiny change in an input image can alter every single pixel value in the corresponding array. Hence, our ideal image classification model should in theory be able to disentangle object pose and deformation from texture and shape.&lt;/p&gt;

&lt;p&gt;For a different type of intuition, let’s again take a look at the following cat images.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;div&gt;
&lt;img src=&quot;/assets/stn2/cat2.jpg&quot; style=&quot;max-width:49%; height:300px;&quot; /&gt;
&lt;img src=&quot;/assets/stn2/cat2_.jpg&quot; style=&quot;max-width:49%; height:300px;&quot; /&gt;
&lt;img src=&quot;/assets/stn2/cat1.jpg&quot; style=&quot;max-width:49%; height:250px;&quot; /&gt;
&lt;img src=&quot;/assets/stn2/cat1_.jpg&quot; style=&quot;max-width:49%; height:250px;&quot; /&gt;
&lt;/div&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt; &lt;b&gt;Left:&lt;/b&gt; Cat images which may present classification challenges. &lt;b&gt;Right:&lt;/b&gt; Transformed images which yield a simplified classification pipeline.&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Would it not be extremely desirable if our model could go from left to right using some sort of crop and scale-normalize combination so as to simplify the subsequent classification task?&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;pooling-layers&quot;&gt;Pooling Layers&lt;/h2&gt;

&lt;p&gt;It turns out that the pooling layers we use in our neural network architectures actually endow our models with a certain degree of spatial invariance. Recall that the pooling operator acts as a sort of downsampling mechanism. It progressively reduces the spatial size (width and height) of the feature map while leaving its depth unchanged, cutting down the number of parameters and the computational cost of the layers that follow.&lt;/p&gt;

&lt;hr /&gt;

&lt;div class=&quot;fig figcenter fighighlight&quot;&gt;
  &lt;img src=&quot;/assets/stn2/pool.jpeg&quot; width=&quot;36%&quot; /&gt;
  &lt;img src=&quot;/assets/stn2/maxpool.jpeg&quot; width=&quot;59%&quot; style=&quot;border-left: 1px solid black;&quot; /&gt;
  &lt;div class=&quot;figcaption&quot;&gt;
    Pooling layer downsamples the volume spatially. &lt;b&gt;Left:&lt;/b&gt; In this example, the input volume of size [224x224x64] is pooled with filter size 2, stride 2 into output volume of size [112x112x64]. &lt;b&gt;Right:&lt;/b&gt; 2x2 max pooling. (&lt;a href=&quot;http://cs231n.github.io/convolutional-networks/#pool&quot;&gt;Image Courtesy&lt;/a&gt;)
  &lt;/div&gt;
&lt;/div&gt;

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;How exactly does it provide invariance?&lt;/strong&gt; Well, think of it this way. The idea behind pooling is to take a complex input, split it up into cells, and “pool” the information from these complex cells to produce a set of simpler cells that describe the output. So for example, say we have 3 images of the number 7, each at a slightly different position. A pool over a small grid in each image would detect the number 7 regardless of its position in that grid since we’d be capturing approximately the same information by aggregating pixel values.&lt;/p&gt;

&lt;p&gt;Now there are a few downsides to pooling which make it an undesirable operator. For one, pooling is &lt;strong&gt;destructive&lt;/strong&gt;. With the usual 2x2 filter and stride of 2, it discards 75% of feature activations, meaning we are guaranteed to lose exact positional information. Now you may be wondering why this is bad since we mentioned earlier that it endowed our network with some spatial robustness. Well, the thing is that positional information is invaluable in visual recognition tasks. Think of our cat classifier above. It may be important to know where the whiskers are relative to, say, the snout. This can’t be achieved when it is precisely this sort of information that we throw away with max pooling.&lt;/p&gt;
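
&lt;p&gt;Both sides of this trade-off can be seen in a tiny NumPy sketch. It assumes 2x2 max pooling with a stride of 2 on a single-channel map with even height and width; the helper &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_pool_2x2&lt;/code&gt; is a hypothetical name of my own.&lt;/p&gt;

```python
import numpy as np


def max_pool_2x2(x):
    """2x2 max pooling with stride 2 on a (H, W) map with even H and W."""
    H, W = x.shape
    # Split the map into non-overlapping 2x2 cells and keep the max of each cell.
    return x.reshape(H // 2, 2, W // 2, 2).max(axis=(1, 3))


a = np.zeros((4, 4))
a[0, 0] = 1.0  # a single activation in the top-left cell

b = np.zeros((4, 4))
b[1, 1] = 1.0  # the same activation, shifted by one pixel inside that cell

pooled_a = max_pool_2x2(a)
pooled_b = max_pool_2x2(b)

print(pooled_a)
print("same output after the shift:", np.array_equal(pooled_a, pooled_b))
print("fraction of activations kept:", pooled_a.size / a.size)
```

&lt;p&gt;The two inputs pool to the same output, which is the invariance we like, but only a quarter of the activations survive and there is no way to tell from the output where inside the cell the activation was.&lt;/p&gt;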

&lt;p&gt;Another limitation of pooling is that it is &lt;strong&gt;local and predefined&lt;/strong&gt;. With a small receptive field, the effects of a pooling operator are only felt towards deeper layers of the network, meaning intermediate feature maps may suffer from large input distortions. And remember, we can’t just increase the receptive field arbitrarily because that would downsample our feature map too aggressively.&lt;/p&gt;

&lt;p&gt;The main takeaway is that ConvNets are not invariant to relatively large input distortions. This limitation is due to having only a restricted, pre-defined pooling mechanism for dealing with spatial variation of the data. This is where Spatial Transformer Networks come into play!&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster. (Geoffrey Hinton, Reddit AMA)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;spatial-transformer-networks-stns&quot;&gt;Spatial Transformer Networks (STNs)&lt;/h2&gt;

&lt;p&gt;The Spatial Transformer mechanism addresses the issues above by providing Convolutional Neural Networks with explicit spatial transformation capabilities. It possesses 3 defining properties that make it very appealing.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;modular&lt;/strong&gt;: STNs can be inserted anywhere into existing architectures with relatively small tweaking.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;differentiable&lt;/strong&gt;: STNs can be trained with backprop allowing for end-to-end training of the models they are injected in.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;dynamic:&lt;/strong&gt; STNs perform active spatial transformation on a feature map for each input sample as compared to the pooling layer which acted identically for all input samples.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;As you can see, the Spatial Transformer is superior to the Pooling operator in all regards. So this raises the following question: &lt;strong&gt;what exactly is a Spatial Transformer?&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn2/stn_arch.png&quot; width=&quot;65%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;The Spatial Transformer module consists of three components shown in the figure above: a &lt;strong&gt;localisation network&lt;/strong&gt;, a &lt;strong&gt;grid generator&lt;/strong&gt; and a &lt;strong&gt;sampler&lt;/strong&gt;. Before we dive into each of their details, I’d like to briefly remind you of a 3-step pipeline we talked about last week.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn2/pipeline.png&quot; width=&quot;75%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://kevinzakka.github.io/2017/01/10/stn-part1/&quot;&gt;Affine Transformation Pipeline&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Recall that we can’t just blindly rush to the input image and apply our affine transformation. It’s important to first create a sampling grid, transform it, and then sample the input image using the grid. With that being said, let’s jump into the core components of the Spatial Transformer.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;localisation-network&quot;&gt;Localisation Network&lt;/h3&gt;

&lt;p&gt;The goal of the localisation network is to spit out the parameters &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt; of the affine transformation that’ll be applied to the input feature map. More formally, our localisation net is defined as follows:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;input&lt;/strong&gt;: feature map U of shape (H, W, C)&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;output&lt;/strong&gt;: transformation matrix &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt; of shape (6,)&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;architecture&lt;/strong&gt;: a fully-connected network or a ConvNet.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;As we train our network, we would like our localisation net to output more and more accurate thetas. &lt;strong&gt;What do we mean by accurate?&lt;/strong&gt; Well, think of our digit 7 rotated by 90 degrees counterclockwise. After say 2 epochs, our localisation net may output a transformation matrix which performs a 45 degree clockwise rotation and after 5 epochs for example, it may actually learn to do a complete 90 degree clockwise rotation. The effect is that our output image looks like a standard digit 7, something our neural network has seen in the training data and can easily classify.&lt;/p&gt;

&lt;p&gt;Another way to look at it is that the localisation network learns to store the knowledge of how to transform each training sample in the weights of its layers.&lt;/p&gt;
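
&lt;p&gt;To make the input and output shapes concrete, here’s a deliberately tiny sketch of a localisation net: a single fully-connected layer in NumPy. Everything about it is an assumption made for illustration, including the 28x28x1 input size, the names &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;localisation_net&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;weights&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bias&lt;/code&gt;, and the choice to start from zero weights and an identity bias so that the untrained module leaves its input untouched.&lt;/p&gt;

```python
import numpy as np


def localisation_net(U, weights, bias):
    """Map a feature map U of shape (H, W, C) to the 6 affine parameters."""
    return weights @ U.reshape(-1) + bias


height, width, channels = 28, 28, 1  # hypothetical input size
rng = np.random.default_rng(0)
U = rng.random((height, width, channels))

# Assumed initialisation: zero weights and a bias equal to the flattened
# identity transform, so that theta is the identity before any training.
weights = np.zeros((6, height * width * channels))
bias = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])

theta = localisation_net(U, weights, bias)
print(theta.shape)
print(theta.reshape(2, 3))
```

&lt;p&gt;During training, gradients flowing back from the classification loss would update the weights and bias so that theta starts to depend on the input.&lt;/p&gt;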

&lt;p&gt;&lt;a name=&quot;toc5&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;parametrised-sampling-grid&quot;&gt;Parametrised Sampling Grid&lt;/h3&gt;

&lt;p&gt;The grid generator’s job is to output a parametrised sampling grid, which is a set of points where the input map &lt;strong&gt;should&lt;/strong&gt; be sampled to produce the desired transformed output.&lt;/p&gt;

&lt;p&gt;Concretely, the grid generator first creates a normalized meshgrid of the same size as the input image U of shape (H, W), that is, a set of indices &lt;script type=&quot;math/tex&quot;&gt;(x^t, y^t)&lt;/script&gt; that cover the whole input feature map (the superscript t here stands for target coordinates in the output feature map). Then, since we’re applying an affine transformation to this grid and would like to use translations, we proceed by adding a row of ones to our coordinate vector to obtain its homogeneous equivalent. This is the little trick we also talked about last week. Finally, we reshape our 6-parameter &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt; to a 2x3 matrix and perform the following multiplication, which results in our desired parametrised sampling grid.&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
\begin{bmatrix}
x^{s} \\
y^{s} \\
\end{bmatrix} = \begin{bmatrix}
\theta_{11} &amp; \theta_{12} &amp; \theta_{13} \\
\theta_{21} &amp; \theta_{22} &amp; \theta_{23}
\end{bmatrix}
%
\begin{bmatrix}
x^t \\
y^t \\
1
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;The column vector &lt;script type=&quot;math/tex&quot;&gt;\begin{bmatrix}
x^s \\
y^s
\end{bmatrix}&lt;/script&gt; consists of a set of indices that tell us where we should sample our input to obtain the desired transformed output.&lt;/p&gt;
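
&lt;p&gt;The grid generator can be sketched in a few lines of NumPy. This is only a minimal illustration, not the implementation from the paper: the function name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;affine_grid&lt;/code&gt; is hypothetical, and normalizing the target coordinates to the range [-1, 1] is an assumption made here.&lt;/p&gt;

```python
import numpy as np

def affine_grid(theta, H, W):
    """Return source coordinates (x_s, y_s), each of shape (H, W)."""
    # normalized target coordinates (assumed range: [-1, 1])
    x_t, y_t = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(-1.0, 1.0, H))
    # homogeneous coordinates: rows are x_t, y_t and ones, shape (3, H*W)
    ones = np.ones(H * W)
    grid = np.stack([x_t.ravel(), y_t.ravel(), ones])
    # reshape the 6 parameters into a 2x3 matrix and transform the grid
    source = np.asarray(theta, dtype=float).reshape(2, 3) @ grid
    return source[0].reshape(H, W), source[1].reshape(H, W)

# identity parameters leave the grid unchanged
x_s, y_s = affine_grid([1, 0, 0, 0, 1, 0], H=3, W=4)
print(x_s[0])     # first row of x coordinates
print(y_s[:, 0])  # first column of y coordinates
```

&lt;p&gt;With the identity parameters, the source coordinates are simply the target coordinates, so every output pixel is read from the same location in the input.&lt;/p&gt;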

&lt;p&gt;&lt;strong&gt;But wait a minute, what if those indices are fractional?&lt;/strong&gt; Bingo! That’s why we learned about bilinear interpolation and this is exactly what we do next.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc6&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;differentiable-image-sampling&quot;&gt;Differentiable Image Sampling&lt;/h3&gt;

&lt;p&gt;Since bilinear interpolation is differentiable, it is perfectly suited to the task at hand. Armed with the input feature map and our parametrised sampling grid, we proceed with bilinear sampling and obtain our output feature map V of shape (H’, W’, C), where the number of channels C is the same as in the input since every channel is sampled identically. Note that this implies that we can perform downsampling and upsampling by specifying the shape of our sampling grid (take that, pooling!). We definitely aren’t restricted to bilinear sampling, and there are other sampling kernels we can use, but the important takeaway is that the kernel must be differentiable to allow the loss gradients to flow all the way back to our localisation network.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn2/transformation.png&quot; width=&quot;60%&quot; style=&quot;border:none;&quot; /&gt;
&lt;div class=&quot;thecap&quot; style=&quot;text-align:justify&quot;&gt;(&lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;Image Courtesy&lt;/a&gt;) Two examples of applying the parameterised sampling grid to an image U producing the output V. &lt;b&gt;(a)&lt;/b&gt; Identity transform (i.e. U = V) &lt;b&gt;(b)&lt;/b&gt; Affine transformation (i.e. rotation)&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;The above illustrates the inner workings of the Spatial Transformer. Basically it boils down to 2 crucial concepts we’ve been talking about all week: an affine transformation followed by bilinear interpolation. Take a moment and admire the elegance of such a mechanism! We’re letting our network learn the optimal affine transformation parameters that will help it ultimately succeed in the classification task &lt;strong&gt;all on its own&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc7&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;fun-with-spatial-transformers&quot;&gt;Fun with Spatial Transformers&lt;/h2&gt;

&lt;p&gt;As a final note, I’ll provide 2 examples that illustrate the power of Spatial Transformers. I’ve attached the references for each example at the bottom of the post, so make sure to look those up if they pique your interest.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc8&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;distorted-mnist&quot;&gt;Distorted MNIST&lt;/h3&gt;

&lt;p&gt;Here is the result of using a spatial transformer as the first layer of a fully-connected network trained for distorted MNIST digit classification.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn2/mnist.png&quot; width=&quot;45%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;(&lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;Image Courtesy&lt;/a&gt;)&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Notice how it has learned to do exactly what we wanted our theoretical “robust” image classification model to do: by zooming in and eliminating background clutter, it has “standardized” the input to facilitate classification. If you want to view a live animation of the transformer in action, click &lt;a href=&quot;https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc9&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;german-traffic-sign-recognition-benchmark-gtsrb-dataset&quot;&gt;German Traffic Sign Recognition Benchmark (GTSRB) dataset&lt;/h3&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;div&gt;
&lt;img src=&quot;/assets/stn2/epoch_evolution.gif&quot; style=&quot;max-width:49%; height:250px;&quot; /&gt;
&lt;img src=&quot;/assets/stn2/moving_evolution.gif&quot; style=&quot;max-width:49%; height:250px;&quot; /&gt;
&lt;/div&gt;
&lt;div class=&quot;thecap&quot;&gt;(&lt;a href=&quot;http://torch.ch/blog/2015/09/07/spatial_transformers.html&quot;&gt;Image Courtesy&lt;/a&gt;) &lt;b&gt;Left&lt;/b&gt;: Behavior of the Spatial Transformer during training. Notice how it learns to focus on the traffic sign, gradually removing the background. &lt;b&gt;Right&lt;/b&gt;: Output for different input images. Note how it stays approximately constant regardless of the input variability and distortion. Pretty neat!&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;&lt;a name=&quot;toc10&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;summary&quot;&gt;Summary&lt;/h2&gt;

&lt;p&gt;In today’s blog post, we went over Google Deepmind’s Spatial Transformer Network paper. We started by introducing the different challenges classification models face, mainly how distortions in the input images can cause our classifiers to fail. One remedy is to use pooling layers; however they possess a few glaring limitations that have made them fall into disuse. The other remedy, and the subject of this blog post, is to use Spatial Transformer Networks.&lt;/p&gt;

&lt;p&gt;A Spatial Transformer is a differentiable module that can be inserted anywhere in a ConvNet architecture to increase its geometric invariance. It effectively endows our networks with the ability to spatially transform feature maps at no extra data or supervision cost. Finally, we saw how the whole mechanism boils down to 2 familiar operations: an affine transformation and bilinear interpolation.&lt;/p&gt;

&lt;p&gt;In next week’s blog post we’ll be using what we’ve learned so far to aid us in coding this paper from scratch in Tensorflow. In the meantime, if you have any questions, feel free to post them in the comment section below.&lt;/p&gt;

&lt;p&gt;Cheers and see you next week!&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc11&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;references&quot;&gt;References&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;The original Deepmind paper - click &lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Kudos to the Torch blog post on STNs which really helped me during the learning process - click &lt;a href=&quot;http://torch.ch/blog/2015/09/07/spatial_transformers.html&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Torch Implementation also helped me grasp the inner workings of STNs - check out this &lt;a href=&quot;https://github.com/qassemoquab/stnbhwd&quot;&gt;repo&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Stanford’s CS231n as always - click &lt;a href=&quot;http://cs231n.github.io&quot;&gt;here&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
        <pubDate>Wed, 18 Jan 2017 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2017/01/18/stn-part2/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2017/01/18/stn-part2/</guid>
        
        <category>deepmind</category>
        
        <category>google</category>
        
        <category>spatial transformer networks</category>
        
        <category>transformations</category>
        
        <category>affine</category>
        
        <category>linear</category>
        
        <category>bilinear interpolation</category>
        
        
      </item>
    
      <item>
        <title>Deep Learning Paper Implementations: Spatial Transformer Networks - Part I</title>
        <description>&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/ai.jpg&quot; width=&quot;40%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://www.technologyreview.com/s/601519/how-to-create-a-malevolent-artificial-intelligence/&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;The first three blog posts in my “Deep Learning Paper Implementations” series will cover &lt;a href=&quot;https://arxiv.org/abs/1506.02025&quot;&gt;Spatial Transformer Networks&lt;/a&gt; introduced by &lt;em&gt;Max Jaderberg, Karen Simonyan, Andrew Zisserman and Koray Kavukcuoglu&lt;/em&gt; of Google Deepmind in 2015. The Spatial Transformer Network is a learnable module aimed at increasing the spatial invariance of Convolutional Neural Networks in a computationally and parameter-efficient manner.&lt;/p&gt;

&lt;p&gt;In this first installment, we’ll be introducing two very important concepts that will prove crucial in understanding the inner workings of the Spatial Transformer layer. We’ll first start by examining a subset of image transformation techniques that fall under the umbrella of &lt;strong&gt;affine transformations&lt;/strong&gt;, and then dive into a procedure that commonly follows these transformations: &lt;strong&gt;bilinear interpolation&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;In the second installment, we’ll be going over the Spatial Transformer Layer in detail and summarizing the paper, and then in the third and final part, we’ll be coding it from scratch in Tensorflow and applying it to the &lt;a href=&quot;http://benchmark.ini.rub.de/?section=gtsrb&amp;amp;subsection=news&quot;&gt;GTSRB dataset&lt;/a&gt; (German Traffic Sign Recognition Benchmark).&lt;/p&gt;

&lt;p&gt;For the full code that appears on this page, visit my &lt;a href=&quot;https://github.com/kevinzakka/blog-code/tree/master/spatial_transformer&quot;&gt;Github Repository&lt;/a&gt;.&lt;/p&gt;

&lt;h4 id=&quot;table-of-contents&quot;&gt;Table of Contents&lt;/h4&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Image Transformations&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;Scale&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;Rotate&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;Shear&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc5&quot;&gt;Translate&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc6&quot;&gt;Bilinear Interpolation&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#toc7&quot;&gt;Motivation&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc8&quot;&gt;Algorithm&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#toc9&quot;&gt;Python Code&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc10&quot;&gt;Results&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc11&quot;&gt;Conclusion&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc12&quot;&gt;References&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;image-transformations&quot;&gt;Image Transformations&lt;/h3&gt;

&lt;p&gt;To lay the groundwork for affine transformations, we first need to talk about linear transformations. To that end, we’ll be restricting ourselves to 2 dimensions and work with matrices.&lt;/p&gt;

&lt;p&gt;We define the following:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;a point K with coordinates
&lt;script type=&quot;math/tex&quot;&gt;\begin{bmatrix}
  x \\
  y
\end{bmatrix}&lt;/script&gt; represented as a &lt;script type=&quot;math/tex&quot;&gt;(2\times1)&lt;/script&gt; column vector.&lt;/li&gt;
  &lt;li&gt;a matrix
&lt;script type=&quot;math/tex&quot;&gt;% &lt;![CDATA[
M=
\begin{bmatrix}
  a &amp; b \\
  c &amp; d
\end{bmatrix} %]]&gt;&lt;/script&gt; represented as a square matrix of shape &lt;script type=&quot;math/tex&quot;&gt;(2\times2)&lt;/script&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;and would like to examine the linear transformation &lt;script type=&quot;math/tex&quot;&gt;T&lt;/script&gt; defined by the matrix product &lt;script type=&quot;math/tex&quot;&gt;K' = T(K) = MK&lt;/script&gt; as we vary the parameters a, b, c and d of M.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Warm-Up Question.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Say we set &lt;script type=&quot;math/tex&quot;&gt;a = d = 1&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;b = c = 0&lt;/script&gt; as follows:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M = \begin{bmatrix}
1 &amp; 0 \\
0 &amp; 1
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;In that case, what transform do you think we would obtain? Go ahead and give it a few moments’ thought…&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Solution.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Let’s write it out:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
K' =  \begin{bmatrix}
1 &amp; 0 \\
0 &amp; 1
\end{bmatrix}
%
\begin{bmatrix}
x \\
y
\end{bmatrix} =
\begin{bmatrix}
x \\
y
\end{bmatrix} = K %]]&gt;&lt;/script&gt;

&lt;p&gt;We’ve actually represented the identity transform, meaning that the point K does not move in the plane. Let us now jump to more interesting transforms.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Scaling.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/scale.png&quot; width=&quot;27%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://people.cs.clemson.edu/~dhouse/courses/401/notes/affines-matrices.pdf&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;We let &lt;script type=&quot;math/tex&quot;&gt;b = c = 0&lt;/script&gt;, and let &lt;script type=&quot;math/tex&quot;&gt;a = p&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;d = q&lt;/script&gt; take on any positive values.&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M = \begin{bmatrix}
p &amp; 0 \\
0 &amp; q
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;Note that there is a special case of scaling called &lt;em&gt;isotropic&lt;/em&gt; scaling in which the scaling factor for both the x and y direction is the same, say &lt;script type=&quot;math/tex&quot;&gt;s&lt;/script&gt;. In that case, enlarging an image would correspond to &lt;script type=&quot;math/tex&quot;&gt;s &gt; 1&lt;/script&gt; while shrinking would correspond to &lt;script type=&quot;math/tex&quot;&gt;% &lt;![CDATA[
s &lt; 1 %]]&gt;&lt;/script&gt;. It’s a bit non-intuitive then that to zoom-in on an image, you need &lt;script type=&quot;math/tex&quot;&gt;% &lt;![CDATA[
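s &lt; 1 %]]&gt;&lt;/script&gt;. The reason is that, in practice, the matrix is applied to the sampling grid rather than to the pixels themselves: a grid shrunk by &lt;script type=&quot;math/tex&quot;&gt;s&lt;/script&gt; covers a smaller region of the input, and that region is then stretched over the whole output.&lt;/p&gt;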

&lt;p&gt;Anyway, performing the matrix product, we obtain&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
K' =  \begin{bmatrix}
p &amp; 0 \\
0 &amp; q
\end{bmatrix}
%
\begin{bmatrix}
x \\
y
\end{bmatrix} =
\begin{bmatrix}
px \\
qy
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Rotation.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/rot.png&quot; width=&quot;19%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://people.cs.clemson.edu/~dhouse/courses/401/notes/affines-matrices.pdf&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Suppose we want to rotate by an angle &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt; about the origin. To do so, we set &lt;script type=&quot;math/tex&quot;&gt;a = d = \cos{\theta}&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;b = -\sin{\theta}&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;c = \sin{\theta}&lt;/script&gt; as follows:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M = \begin{bmatrix}
\cos{\theta} &amp; -\sin{\theta} \\
\sin{\theta} &amp; \cos{\theta}
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;We thus obtain&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
K' =  \begin{bmatrix}
\cos{\theta} &amp; -\sin{\theta} \\
\sin{\theta} &amp; \cos{\theta}
\end{bmatrix}
%
\begin{bmatrix}
x \\
y
\end{bmatrix} =
\begin{bmatrix}
x\cos{\theta}- y\sin{\theta} \\
x\sin{\theta} + y\cos{\theta}
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Shear.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/shear.png&quot; width=&quot;27%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://people.cs.clemson.edu/~dhouse/courses/401/notes/affines-matrices.pdf&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;When we shear an image, we offset the y direction by a distance proportional to x, and the x direction by a distance proportional to y. For example, when we go from normal text to italics, we are effectively applying a shear transform (think about shearing a deck of cards if that helps).&lt;/p&gt;

&lt;p&gt;To achieve shearing, we set &lt;script type=&quot;math/tex&quot;&gt;a = d = 1&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;b = m&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;c = n&lt;/script&gt; as follows:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M =  \begin{bmatrix}
1 &amp; m \\
n &amp; 1
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;This yields&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
K' =  \begin{bmatrix}
1 &amp; m \\
n &amp; 1
\end{bmatrix}
%
\begin{bmatrix}
x \\
y
\end{bmatrix} =
\begin{bmatrix}
x + my \\
y + nx
\end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;hr /&gt;

&lt;p&gt;In summary, we have defined 3 basic linear transformations:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;scaling:&lt;/strong&gt; scales the x and y direction by a scalar.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;shearing:&lt;/strong&gt; offsets x by an amount proportional to y, and y by an amount proportional to x.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;rotating:&lt;/strong&gt; rotates the points around the origin by an angle &lt;script type=&quot;math/tex&quot;&gt;\theta&lt;/script&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Now the nice thing about matrices is that we can collapse sequential linear transformations into a single transformation matrix. For example, say we would like to apply a shear, a scale and then a rotation to our column vector K. Given that these transformations can be represented by the matrices &lt;script type=&quot;math/tex&quot;&gt;H&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;S&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;R&lt;/script&gt;, and respecting the order of transformations, we can write down this operation as&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;K' = R \big[ S \big( HK \big) \big]&lt;/script&gt;

&lt;p&gt;But recall that matrix multiplication is associative! So this reduces to&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\boxed{K' = MK}&lt;/script&gt;

&lt;p&gt;where &lt;script type=&quot;math/tex&quot;&gt;M = RSH&lt;/script&gt;. Be mindful of the order since matrix multiplication &lt;script type=&quot;math/tex&quot;&gt;\color{red}{\text{is not}}&lt;/script&gt; commutative.&lt;/p&gt;

&lt;p&gt;A beautiful consequence of this formula is that if we are given multiple transformations to do for a very high-dimensional vector, then we can basically carry out a single matrix multiplication rather than repeatedly manipulating the high-dimensional vector for every sequential transformation.&lt;/p&gt;
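
&lt;p&gt;Here is a small NumPy check of the two properties above. The shear, scale and angle values are arbitrary choices made for illustration only.&lt;/p&gt;

```python
import numpy as np

angle = np.pi / 2  # 90 degree counterclockwise rotation
H = np.array([[1.0, 0.5], [0.0, 1.0]])   # shear with m = 0.5, n = 0
S = np.array([[2.0, 0.0], [0.0, 3.0]])   # scale with p = 2, q = 3
R = np.array([[np.cos(angle), -np.sin(angle)],
              [np.sin(angle),  np.cos(angle)]])

K = np.array([[1.0], [1.0]])  # the point (1, 1) as a column vector

step_by_step = R @ (S @ (H @ K))  # shear, then scale, then rotate
M = R @ S @ H                     # collapse into one matrix

print(np.allclose(M @ K, step_by_step))  # True: associativity
print(np.allclose(H @ S @ R, M))         # False: order matters
```

&lt;p&gt;The point (1, 1) is sheared to (1.5, 1), scaled to (3, 3) and rotated to (-3, 3), and the single matrix M gives the same answer in one multiplication.&lt;/p&gt;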

&lt;hr /&gt;

&lt;p&gt;&lt;a name=&quot;toc5&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Translation.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The only downside to this &lt;script type=&quot;math/tex&quot;&gt;2 \times 2&lt;/script&gt; matrix representation is that we cannot represent translation since it isn’t a linear transformation. Translation, however, is a very important and much-needed transformation, so we would like to be able to encapsulate it in our matrix representation.&lt;/p&gt;

&lt;p&gt;To solve this dilemma, we represent our 2D vectors in 3D using &lt;strong&gt;homogeneous coordinates&lt;/strong&gt; as follows:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;our point K becomes a &lt;script type=&quot;math/tex&quot;&gt;(3\times1)&lt;/script&gt; column vector
&lt;script type=&quot;math/tex&quot;&gt;\begin{bmatrix}
  x \\
  y \\
  1
\end{bmatrix}&lt;/script&gt;&lt;/li&gt;
  &lt;li&gt;our matrix M becomes a &lt;script type=&quot;math/tex&quot;&gt;(3\times3)&lt;/script&gt; square matrix
&lt;script type=&quot;math/tex&quot;&gt;% &lt;![CDATA[
M=
\begin{bmatrix}
  a &amp; b &amp; 0 \\
  c &amp; d &amp; 0 \\
  0 &amp; 0 &amp; 1
\end{bmatrix} %]]&gt;&lt;/script&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;To represent a translation, all we have to do is place 2 new parameters &lt;script type=&quot;math/tex&quot;&gt;e&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;f&lt;/script&gt; in our third column like so&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M=
  \begin{bmatrix}
    a &amp; b &amp; e \\
    c &amp; d &amp; f \\
    0 &amp; 0 &amp; 1
  \end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;and we can thus carry out translations as linear transformations in homogeneous coordinates. Note that if we require a 2D output, then all we need to do is drop the last row of M, making it a &lt;script type=&quot;math/tex&quot;&gt;2 \times 3&lt;/script&gt; matrix, and leave the homogeneous K untouched.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Example.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Translate both the x and y direction by &lt;script type=&quot;math/tex&quot;&gt;\Delta&lt;/script&gt;. Result should be 2D.&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
K' =  \begin{bmatrix}
1 &amp; 0 &amp; \Delta \\
0 &amp; 1 &amp; \Delta
\end{bmatrix}
%
\begin{bmatrix}
x \\
y \\
1
\end{bmatrix} =
\begin{bmatrix}
x + \Delta \\
y + \Delta
\end{bmatrix} %]]&gt;&lt;/script&gt;
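
&lt;p&gt;The same example can be checked numerically. The point (3, 4) and the offset of 2 are arbitrary values picked for illustration.&lt;/p&gt;

```python
import numpy as np

delta = 2.0
M = np.array([[1.0, 0.0, delta],
              [0.0, 1.0, delta]])   # 2x3 matrix: identity plus a translation
K = np.array([3.0, 4.0, 1.0])       # the point (3, 4) in homogeneous form

K_prime = M @ K
print(K_prime)  # the 2D point (x + delta, y + delta) = (5, 6)
```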

&lt;p&gt;&lt;strong&gt;Summary.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/affine.png&quot; width=&quot;40%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://people.cs.clemson.edu/~dhouse/courses/401/notes/affines-matrices.pdf&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;By using a little trick, we were able to add a new transformation to our repertoire of linear transformations. This transformation, called translation, is an affine transformation. Hence, we can generalize our results and represent our 4 affine transformations (all linear transformations are affine) by the 6-parameter matrix below, whose entries we relabel row by row:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;% &lt;![CDATA[
M=
  \begin{bmatrix}
    a &amp; b &amp; c \\
    d &amp; e &amp; f
  \end{bmatrix} %]]&gt;&lt;/script&gt;

&lt;p&gt;&lt;a name=&quot;toc6&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;bilinear-interpolation&quot;&gt;Bilinear Interpolation&lt;/h3&gt;

&lt;p&gt;&lt;a name=&quot;toc7&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Motivation.&lt;/strong&gt; When an image undergoes an affine transformation such as a rotation or scaling, the pixels in the image get moved around. This can be especially problematic when a pixel location in the output does not map directly to one in the input image.&lt;/p&gt;

&lt;p&gt;In the illustration below, you can clearly see that the rotation places some points at locations that are not centered in the squares. This means that they would not have a corresponding pixel value in the original image.&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/stickman.png&quot; width=&quot;70%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;http://northstar-www.dartmouth.edu/doc/idl/html_6.2/Interpolation_Methods.html&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;So for example, suppose that after rotating an image, we need to find the pixel value at the location (6.7, 3.2). The problem with this is that there is no such thing as fractional pixel locations.&lt;/p&gt;

&lt;p&gt;To solve this problem, bilinear interpolation uses the 4 nearest pixel values which are located in diagonal directions from a given location in order to find the appropriate color intensity values of that pixel. The result is smoother and more realistic images!&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc8&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Algorithm.&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/stn/interpol.png&quot; width=&quot;35%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;https://en.wikipedia.org/wiki/Bilinear_interpolation&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;Our goal is to find the pixel value of the point P. To do so, we calculate the pixel value of &lt;script type=&quot;math/tex&quot;&gt;R_1&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;R_2&lt;/script&gt; using a weighted average of &lt;script type=&quot;math/tex&quot;&gt;(Q_{11}, Q_{21})&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;(Q_{12}, Q_{22})&lt;/script&gt; respectively. Then, we use a weighted average of &lt;script type=&quot;math/tex&quot;&gt;R_2&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;R_1&lt;/script&gt; to find the value of P.&lt;/p&gt;

&lt;p&gt;Effectively, we are interpolating in the x direction and then the y direction, hence the name bilinear interpolation. You could just as well flip the order of interpolation and get the exact same value.&lt;/p&gt;

&lt;p&gt;So given a point &lt;script type=&quot;math/tex&quot;&gt;P = (x, y)&lt;/script&gt; and 4 corner coordinates &lt;script type=&quot;math/tex&quot;&gt;Q_{11} = (x_1, y_1)&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;Q_{21} = (x_2, y_1)&lt;/script&gt;, &lt;script type=&quot;math/tex&quot;&gt;Q_{12} = (x_1, y_2)&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;Q_{22} = (x_2, y_2)&lt;/script&gt;, we first interpolate in the x-direction:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;R_1 = \frac{x_2 - x}{x_2 - x_1}Q_{11} + \frac{x - x_1}{x_2 - x_1}Q_{21}&lt;/script&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;R_2 = \frac{x_2 - x}{x_2 - x_1}Q_{12} + \frac{x - x_1}{x_2 - x_1}Q_{22}&lt;/script&gt;

&lt;p&gt;and finally in the y-direction:&lt;/p&gt;

&lt;script type=&quot;math/tex; mode=display&quot;&gt;\boxed{P = \frac{y_2 - y}{y_2 - y_1}R_1 + \frac{y - y_1}{y_2 - y_1}R_2}&lt;/script&gt;
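
&lt;p&gt;The three formulas translate almost line by line into Python. This is a scalar sketch for a single point, with a hypothetical function name and made-up corner values, rather than the vectorized sampler we need for whole images.&lt;/p&gt;

```python
def bilinear(x, y, x1, y1, x2, y2, q11, q21, q12, q22):
    """Interpolate the value at (x, y) from the 4 corner values."""
    # interpolate along x on the two rows y = y1 and y = y2
    r1 = (x2 - x) / (x2 - x1) * q11 + (x - x1) / (x2 - x1) * q21
    r2 = (x2 - x) / (x2 - x1) * q12 + (x - x1) / (x2 - x1) * q22
    # interpolate along y between the two intermediate values
    return (y2 - y) / (y2 - y1) * r1 + (y - y1) / (y2 - y1) * r2

# the fractional location (6.7, 3.2) sits between columns 6, 7 and rows 3, 4
value = bilinear(6.7, 3.2, 6, 3, 7, 4, q11=10.0, q21=20.0, q12=30.0, q22=40.0)
print(value)  # approximately 21
```

&lt;p&gt;Here the intermediate values are roughly 17 and 37, and weighting them by 0.8 and 0.2 gives a result of about 21. Evaluating the function exactly at a corner returns that corner’s value.&lt;/p&gt;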

&lt;p&gt;&lt;a name=&quot;toc9&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Python Code.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;One very very important note before we jump into the code!&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;An image processing affine transformation usually follows the 3-step pipeline below:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;First, we create a sampling grid composed of &lt;script type=&quot;math/tex&quot;&gt;(x, y)&lt;/script&gt; coordinates. For example, given a 400x400 grayscale image, we create a meshgrid of the same dimensions, that is, evenly spaced &lt;script type=&quot;math/tex&quot;&gt;x \in [0, W-1]&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;y \in [0, H-1]&lt;/script&gt;.&lt;/li&gt;
  &lt;li&gt;We then apply the transformation matrix to the sampling grid generated in the step above.&lt;/li&gt;
  &lt;li&gt;Finally, we sample the resulting grid from the original image using the desired interpolation technique.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;As you can see, this is different from directly applying a transform to the original image.&lt;/p&gt;
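
&lt;p&gt;To make the 3 steps concrete, here is a toy version on a 4x4 array. It is a sketch under simplifying assumptions: the function name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;warp_nearest&lt;/code&gt; is hypothetical, coordinates are plain pixel indices, and flooring to the nearest lower pixel stands in for bilinear interpolation to keep the example short.&lt;/p&gt;

```python
import numpy as np

def warp_nearest(img, M):
    """Warp a 2D image with a 2x3 matrix M applied to the sampling grid."""
    H, W = img.shape
    # step 1: sampling grid of (x, y) pixel coordinates, in homogeneous form
    x, y = np.meshgrid(np.arange(W), np.arange(H))
    grid = np.stack([x.ravel(), y.ravel(), np.ones(H * W)])
    # step 2: apply the transformation to the grid, not to the image
    x_s, y_s = M @ grid
    # step 3: sample the input at the transformed locations
    # (floor + clip stands in for a proper interpolation scheme)
    cols = np.clip(np.floor(x_s), 0, W - 1).astype(int)
    rows = np.clip(np.floor(y_s), 0, H - 1).astype(int)
    return img[rows, cols].reshape(H, W)

img = np.arange(16).reshape(4, 4)
zoom = np.array([[0.5, 0.0, 0.0],
                 [0.0, 0.5, 0.0]])  # isotropic scaling with s = 0.5
print(warp_nearest(img, zoom))
```

&lt;p&gt;Because the grid is shrunk by a factor of 0.5, it only covers the top-left 2x2 corner of the input, and each of those 4 pixels ends up repeated in a 2x2 block of the output. In other words, scaling the grid down zooms in on the image.&lt;/p&gt;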

&lt;hr /&gt;

&lt;p&gt;I’ve attached 2 cat images in the Github Repository mentioned at the top of this page which you should go ahead and download. Save them to your Desktop in a folder called &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data/&lt;/code&gt; or make sure to update the path location if you choose differently.&lt;/p&gt;

&lt;p&gt;I’ve also written a function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;load_img()&lt;/code&gt; that converts images to numpy arrays. I won’t go into its details but it’s pretty basic and you shouldn’t take long to understand what it does. Note that you’ll need both PIL and Numpy to reproduce the results below.&lt;/p&gt;

&lt;p&gt;Armed with this function, let’s load both cat images and concatenate them into a single input array. We’re working with 2 images because we want to make our code as general as possible.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;PIL&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Image&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# params
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;DIMS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;400&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;400&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;CAT1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'cat1.jpg'&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;CAT2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'cat2.jpg'&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# load both cat images
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;load_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CAT1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DIMS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;img2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;load_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CAT2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DIMS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# concat into tensor of shape (2, 400, 400, 3)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# dimension sanity check
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Input Img Shape: {}&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;format&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Given that we have 2 images, our batch size is equal to 2. This means that we need one transformation matrix M for each image in the batch.&lt;/p&gt;

&lt;p&gt;Let’s go ahead and initialize 2 identity transform matrices. This is the simplest case, and if we implement our bilinear sampler correctly, we should expect our output image to be almost identical to the input image.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# grab shape
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# initialize M to identity transform
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# repeat num_batch times
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;resize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;(Recall that our general affine transformation matrix is &lt;script type=&quot;math/tex&quot;&gt;2 \times 3&lt;/script&gt; if we want to include translation.)&lt;/p&gt;
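
&lt;p&gt;As a quick sanity check on the repetition step, here is a minimal sketch (assuming a batch size of 2, as in our example) suggesting that &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.resize&lt;/code&gt; simply tiles the matrix along a new batch axis, which is the same result an explicit &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.tile&lt;/code&gt; would give:&lt;/p&gt;

```python
import numpy as np

num_batch = 2  # assumed batch size for this illustration

# a single 2x3 identity affine transform
M = np.array([[1., 0., 0.], [0., 1., 0.]])

# np.resize repeats the flattened data until the new shape is filled
M_batch = np.resize(M, (num_batch, 2, 3))

# the same result, written as an explicit tiling along a new batch axis
M_tiled = np.tile(M[None], (num_batch, 1, 1))

print(M_batch.shape)
print(np.array_equal(M_batch, M_tiled))
```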

&lt;p&gt;Now we need to write a function that will generate a meshgrid for us and output a sampling grid resulting from the product of this meshgrid and our transformation matrix M.&lt;/p&gt;

&lt;p&gt;Let’s go ahead and generate our meshgrid. We’ll create a normalized one: the values of x and y range from -1 to 1, with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;width&lt;/code&gt; values along x and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;height&lt;/code&gt; values along y. Note that for images, x corresponds to the width of the image (i.e. number of columns of the matrix) while y corresponds to the height of the image (i.e. number of rows of the matrix).&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# create normalized 2D grid
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linspace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linspace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;x_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;meshgrid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
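
&lt;p&gt;To see what the meshgrid looks like, here is a small sketch on a hypothetical 2-by-3 image (the sizes are chosen only so the arrays are easy to print). Both outputs have shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(H, W)&lt;/code&gt;: x varies along the columns and y along the rows.&lt;/p&gt;

```python
import numpy as np

H, W = 2, 3  # tiny hypothetical image size, for readability

x = np.linspace(-1, 1, W)
y = np.linspace(-1, 1, H)
x_t, y_t = np.meshgrid(x, y)

# both outputs have shape (H, W): rows index y, columns index x
print(x_t)
print(y_t)
```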

&lt;p&gt;Then we need to augment the dimensions to create homogeneous coordinates.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# reshape to (xt, yt, 1)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sampling_grid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vstack&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;So we’ve created 1 grid here, but we need &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_batch&lt;/code&gt; grids. Same as above, our one-liner below repeats our array &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_batch&lt;/code&gt; times.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# repeat grid num_batch times
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sampling_grid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;resize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sampling_grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now we perform step 2 of our image transformation pipeline.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# transform the sampling grid i.e. batch multiply
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sampling_grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# batch grid has shape (num_batch, 2, H*W)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# reshape to (num_batch, height, width, 2)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;moveaxis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
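
&lt;p&gt;The following sketch strings the grid steps together on hypothetical small sizes and checks the shapes along the way. With the identity transform, the transformed grid should come out equal to the original meshgrid.&lt;/p&gt;

```python
import numpy as np

num_batch, H, W = 2, 2, 3  # hypothetical small sizes

# normalized meshgrid in homogeneous coordinates, repeated per batch element
x_t, y_t = np.meshgrid(np.linspace(-1, 1, W), np.linspace(-1, 1, H))
ones = np.ones(H * W)
sampling_grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
sampling_grid = np.resize(sampling_grid, (num_batch, 3, H * W))

# identity transform, repeated per batch element
M = np.resize(np.array([[1., 0., 0.], [0., 1., 0.]]), (num_batch, 2, 3))

batch_grids = np.matmul(M, sampling_grid)              # (num_batch, 2, H*W)
batch_grids = batch_grids.reshape(num_batch, 2, H, W)  # (num_batch, 2, H, W)
batch_grids = np.moveaxis(batch_grids, 1, -1)          # (num_batch, H, W, 2)

# under the identity transform, the grid is unchanged
print(batch_grids.shape)
print(np.allclose(batch_grids[0, :, :, 0], x_t))
print(np.allclose(batch_grids[0, :, :, 1], y_t))
```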

&lt;p&gt;Finally, let’s write our bilinear sampler. Given our coordinates &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt; in the sampling grid, we want to interpolate the pixel value in the original image.&lt;/p&gt;

&lt;p&gt;Let’s start by separating the x and y dimensions and rescaling them to the width/height interval.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;x_s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squeeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_grids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squeeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# rescale x to [0, W] and y to [0, H]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now for each coordinate &lt;script type=&quot;math/tex&quot;&gt;(x_i, y_i)&lt;/script&gt; we want to grab 4 corner coordinates.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# grab 4 nearest corner points for each (x_i, y_i)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;floor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;floor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;(Note that for non-integer coordinates, the ceiling function would give the same result as the increment by 1. At exactly integer coordinates, however, the ceiling equals the floor, so the two corners would coincide and all four interpolation weights would be zero, which is why we increment instead.)&lt;/p&gt;

&lt;p&gt;Now we must make sure that no value goes beyond the image boundaries. For example, suppose we have &lt;script type=&quot;math/tex&quot;&gt;x = 399&lt;/script&gt;, then &lt;script type=&quot;math/tex&quot;&gt;x_0 = 399&lt;/script&gt; and &lt;script type=&quot;math/tex&quot;&gt;x_1 = x_0 + 1 = 400&lt;/script&gt; which would result in a numpy indexing error. Thus we clip our corner coordinates in the following way:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# make sure it's inside img range [0, H-1] or [0, W-1]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now we use advanced numpy indexing to grab the pixel value for each corner coordinate. These correspond to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(x0, y0)&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(x0, y1)&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(x1, y0)&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(x1, y1)&lt;/code&gt;.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# look up pixel values at corner coords
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Ia&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Ib&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Ic&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Id&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
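
&lt;p&gt;If the indexing expression looks cryptic, the sketch below may help. It uses a hypothetical tiny batch and constant corner coordinates (both are assumptions made purely for illustration) to show how the batch index of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(num_batch, 1, 1)&lt;/code&gt; broadcasts against the coordinate arrays, so that each image is only indexed with its own coordinates.&lt;/p&gt;

```python
import numpy as np

num_batch, H, W, C = 2, 2, 2, 3  # hypothetical small sizes
input_img = np.arange(num_batch * H * W * C).reshape(num_batch, H, W, C)

# integer corner coordinates, one per output pixel: shape (num_batch, H, W)
# (constant here for illustration: every output pixel looks up row 0, column 1)
y0 = np.zeros((num_batch, H, W), dtype=np.int64)
x0 = np.ones((num_batch, H, W), dtype=np.int64)

# the batch index has shape (num_batch, 1, 1) and broadcasts against y0 and x0
b = np.arange(num_batch)[:, None, None]
Ia = input_img[b, y0, x0]

# one C-channel pixel per output location: (num_batch, H, W, C)
print(Ia.shape)
```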

&lt;p&gt;Almost there! Now, we calculate the weight coefficients,&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# calculate deltas
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wa&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;and finally, multiply and add according to the formula mentioned previously.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# add dimension for addition
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wa&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wa&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# compute output
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wa&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Ia&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Ib&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wc&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Ic&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wd&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Id&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
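
&lt;p&gt;To make the weighting concrete, here is a worked example for a single sampling location on a hypothetical 2-by-2 single-channel image (the pixel values and the location are made up for illustration). Each corner is weighted by the area of the sub-rectangle opposite to it, so the four weights sum to 1.&lt;/p&gt;

```python
import numpy as np

# hypothetical 2x2 single-channel image, indexed as img[y, x]
img = np.array([[10., 20.],
                [30., 40.]])

x, y = 0.25, 0.5  # sampling location in pixel coordinates
x0, y0 = int(np.floor(x)), int(np.floor(y))
x1, y1 = x0 + 1, y0 + 1

# each corner is weighted by the area of the opposite sub-rectangle
wa = (x1 - x) * (y1 - y)  # weight for (x0, y0)
wb = (x1 - x) * (y - y0)  # weight for (x0, y1)
wc = (x - x0) * (y1 - y)  # weight for (x1, y0)
wd = (x - x0) * (y - y0)  # weight for (x1, y1)

out = wa * img[y0, x0] + wb * img[y1, x0] + wc * img[y0, x1] + wd * img[y1, x1]
print(wa + wb + wc + wd)
print(out)
```

&lt;p&gt;Since the location is closer to the left column than to the right one, the result lands closer to the average of the left pixels (20) than to the average of the right pixels (30).&lt;/p&gt;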

&lt;hr /&gt;

&lt;p&gt;&lt;a name=&quot;toc10&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;results&quot;&gt;Results&lt;/h3&gt;

&lt;p&gt;So now that we’ve gone through the whole code incrementally, let’s have some fun and experiment with different values of the transformation matrix M.&lt;/p&gt;

&lt;p&gt;The first thing you need to do is copy and paste the whole code, which has been made more modular. Now let’s test if our function works correctly.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Identity Transform.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Add the following 2 lines at the end of the script and execute.&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;plt.imshow(out[1])
plt.show()
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;img src=&quot;/assets/stn/bef1.png&quot; width=&quot;200&quot; /&gt; &lt;img src=&quot;/assets/stn/aft1.png&quot; width=&quot;300&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Translation.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Say we want to translate the picture by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.5&lt;/code&gt; only in the x direction. This should shift the image to the left.&lt;/p&gt;

&lt;p&gt;Edit the line of your code that defines &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;M&lt;/code&gt; as follows:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;M = np.array([[1., 0., 0.5], [0., 1., 0.]])
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;img src=&quot;/assets/stn/bef1.png&quot; width=&quot;200&quot; /&gt; &lt;img src=&quot;/assets/stn/aft2.png&quot; width=&quot;300&quot; /&gt;
&lt;/p&gt;
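
&lt;p&gt;To see why a positive x offset moves the picture to the left, here is a minimal sketch. It assumes that M maps output (target) coordinates to input (source) coordinates and that coordinates are normalized to the range [-1, 1]; the three sample points are made up for illustration.&lt;/p&gt;

```python
import numpy as np

# Translation matrix from the text: shift by 0.5 along x only.
M = np.array([[1., 0., 0.5], [0., 1., 0.]])

# Three target points in homogeneous form (x, y, 1), one per column.
# The normalized [-1, 1] coordinate range is an assumption of this sketch.
target = np.array([[-1., 0., 0.5],
                   [0., 0., 0.],
                   [1., 1., 1.]])

# Each column of `source` is the input location that the corresponding
# output pixel samples from.
source = M @ target
print(source)
```

&lt;p&gt;Every output pixel reads from a location 0.5 further to the right in the input, so the content appears shifted to the left in the output.&lt;/p&gt;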

&lt;p&gt;&lt;strong&gt;Rotation.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Finally, say we want to rotate the picture by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;45&lt;/code&gt; degrees. Given that &lt;script type=&quot;math/tex&quot;&gt;\cos{(45^{\circ})} = \sin{(45^{\circ})} = \frac{\sqrt{2}}{2} \approx 0.707&lt;/script&gt;, edit just this line of your code as follows:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;M = np.array([[0.707, -0.707, 0.], [0.707, 0.707, 0.]])
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;img src=&quot;/assets/stn/bef1.png&quot; width=&quot;200&quot; /&gt; &lt;img src=&quot;/assets/stn/aft3.png&quot; width=&quot;300&quot; /&gt;
&lt;/p&gt;
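
&lt;p&gt;Rather than typing the rounded values by hand, the matrix can be built from the angle. The sketch below uses a hypothetical helper, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;rotation_matrix&lt;/code&gt;, which is not part of the original script; note that NumPy’s trigonometric functions expect radians.&lt;/p&gt;

```python
import numpy as np

def rotation_matrix(degrees):
    """Return a 2x3 affine matrix for a rotation by `degrees` (hypothetical helper)."""
    theta = np.deg2rad(degrees)  # NumPy trig functions work in radians
    c, s = np.cos(theta), np.sin(theta)
    # Same layout as M in the text: rotation block plus a zero translation column.
    return np.array([[c, -s, 0.],
                     [s, c, 0.]])

M = rotation_matrix(45)
print(np.round(M, 3))
```

&lt;p&gt;Rounded to 3 decimals, this matches the matrix written out above, and other angles only require changing the argument.&lt;/p&gt;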

&lt;p&gt;&lt;a name=&quot;toc11&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h3&gt;

&lt;p&gt;In this blog post, we went over basic linear transformations such as rotation, shear and scale before generalizing to affine transformations, which also include translations. Then, we saw the importance of bilinear interpolation in the context of these transformations. Finally, we went over the algorithm, coded it from scratch in Python and wrote 2 methods that helped us visualize these transformations according to a 3-step image processing pipeline.&lt;/p&gt;

&lt;p&gt;In the next installment of this series, we’ll go over the Spatial Transformer Network layer in detail as well as summarize the paper it is described in.&lt;/p&gt;

&lt;p&gt;See you next week!&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc12&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;references&quot;&gt;References&lt;/h3&gt;

&lt;p&gt;A big thank you to &lt;a href=&quot;https://twitter.com/edersantana&quot;&gt;Eder Santana&lt;/a&gt; for introducing me to this paper!&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://en.wikipedia.org/wiki/Bilinear_interpolation&quot;&gt;Bilinear Interpolation Wikipedia&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;http://supercomputingblog.com/graphics/coding-bilinear-interpolation/&quot;&gt;Bilinear Interpolation&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://people.cs.clemson.edu/~dhouse/courses/401/notes/affines-matrices.pdf&quot;&gt;Matrix Transformations PDF&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;http://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python&quot;&gt;Bilinear Interpolation Code&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
        <pubDate>Tue, 10 Jan 2017 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2017/01/10/stn-part1/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2017/01/10/stn-part1/</guid>
        
        <category>deepmind</category>
        
        <category>google</category>
        
        <category>spatial transformer networks</category>
        
        <category>transformations</category>
        
        <category>affine</category>
        
        <category>linear</category>
        
        <category>bilinear interpolation</category>
        
        
      </item>
    
      <item>
        <title>Nuts and Bolts of Applying Deep Learning</title>
        <description>&lt;div class=&quot;imgcap&quot;&gt;
&lt;img src=&quot;/assets/app_dl/bolts.jpg&quot; width=&quot;40%&quot; style=&quot;border:none;&quot; /&gt;&lt;div class=&quot;thecap&quot; style=&quot;text-align:center&quot;&gt;&lt;a href=&quot;http://nutsandbolts.mit.edu/&quot;&gt;Image Courtesy&lt;/a&gt;&lt;/div&gt;
&lt;/div&gt;

&lt;p&gt;This weekend was very hectic (catching up on courses and studying for a statistics quiz), but I managed to squeeze in some time to watch the &lt;a href=&quot;http://www.bayareadlschool.org/&quot;&gt;Bay Area Deep Learning School&lt;/a&gt; livestream on YouTube. For those of you wondering what that is, BADLS is a 2-day conference hosted at Stanford University, consisting of back-to-back presentations on a variety of topics ranging from NLP and Computer Vision to Unsupervised Learning and Reinforcement Learning. Additionally, top DL software libraries such as Torch, Theano and TensorFlow were presented.&lt;/p&gt;

&lt;p&gt;There were some super interesting talks from leading experts in the field: &lt;a href=&quot;http://www.dmi.usherb.ca/~larocheh/index_en.html&quot;&gt;Hugo Larochelle&lt;/a&gt; from Twitter, &lt;a href=&quot;http://cs.stanford.edu/people/karpathy/&quot;&gt;Andrej Karpathy&lt;/a&gt; from OpenAI, &lt;a href=&quot;http://www.iro.umontreal.ca/~bengioy/yoshua_en/index.html&quot;&gt;Yoshua Bengio&lt;/a&gt; from the Université de Montréal, and &lt;a href=&quot;http://www.andrewng.org/&quot;&gt;Andrew Ng&lt;/a&gt; from Baidu, to name a few. Of the plethora of presentations, there was one somewhat non-technical one given by Andrew that really piqued my interest.&lt;/p&gt;

&lt;p&gt;In this blog post, I’m gonna try and give an overview of the main ideas outlined in his talk. The goal is to pause a bit and examine the ongoing trends in Deep Learning thus far, as well as gain some insight into applying DL in practice.&lt;/p&gt;

&lt;p&gt;By the way, if you missed out on the livestreams, you can still view them at the following: &lt;a href=&quot;https://www.youtube.com/watch?v=eyovmAtoUx0&quot;&gt;Day 1&lt;/a&gt; and &lt;a href=&quot;https://www.youtube.com/watch?v=9dXiAecyJrY&quot;&gt;Day 2&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Table of Contents&lt;/strong&gt;:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#toc1&quot;&gt;Major Deep Learning Trends&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc2&quot;&gt;End-to-End Deep Learning&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc3&quot;&gt;Bias-Variance Tradeoff&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc4&quot;&gt;Human-level Performance&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#toc5&quot;&gt;Personal Advice&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a name=&quot;toc1&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;major-deep-learning-trends&quot;&gt;Major Deep Learning Trends&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;Why do DL algorithms work so well?&lt;/strong&gt; According to Ng, with the rise of the Internet, Mobile and IoT era, the amount of data accessible to us has greatly increased. This correlates directly with a boost in the performance of neural network models, especially the larger ones which have the capacity to absorb all this data.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/perf_vs_data.png&quot; width=&quot;450&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;However, in the small data regime (left-hand side of the x-axis), the relative ordering of the algorithms is not that well defined and really depends on who is more motivated to engineer their features better, or refine and tune the hyperparameters of their model.&lt;/p&gt;

&lt;p&gt;Thus this trend is more prevalent in the big data realm where hand engineering effectively gets replaced by end-to-end approaches and bigger neural nets combined with a lot of data tend to outperform all other models.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Machine Learning and HPC team.&lt;/strong&gt; The rise of big data and the need for larger models has started to put pressure on companies to hire a Computer Systems team. This is because some of the HPC (high-performance computing) applications require highly specialized knowledge and it is difficult to find researchers and engineers with sufficient knowledge in both fields. Thus, cooperation from both teams is the key to boosting performance in AI companies.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Categorizing DL models.&lt;/strong&gt; Work in DL can be categorized in the following 4 buckets:&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/bucket.svg&quot; width=&quot;350&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;Most of the value in the industry today is driven by the models in the orange blob (innovation and monetization mostly) but Andrew believes that &lt;strong&gt;unsupervised deep learning&lt;/strong&gt; is a super-exciting field that has loads of potential for the future.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc2&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;the-rise-of-end-to-end-dl&quot;&gt;The rise of End-to-End DL&lt;/h3&gt;

&lt;p&gt;A major improvement in the end-to-end approach has been the fact that outputs are becoming more and more complicated. For example, rather than just outputting a simple class score such as 0 or 1, algorithms are starting to generate richer outputs: images in the case of GANs, full captions with RNNs and, most recently, audio as in DeepMind’s WaveNet.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;So what exactly does end-to-end training mean?&lt;/strong&gt; Essentially, it means that AI practitioners are shying away from intermediate representations and going directly from one end (raw input) to the other end (output). Here’s an example from speech recognition.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/end-to-end.svg&quot; width=&quot;340&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Are there any disadvantages to this approach?&lt;/strong&gt; End-to-end approaches are data-hungry, meaning they only perform well when provided with a huge dataset of labelled examples. In practice, not all applications have the luxury of large labelled datasets, so in those settings other approaches, which allow hand-engineered information and field expertise to be added into the model, have the upper hand. As an example, in a self-driving car setting, going directly from the raw image to the steering direction is pretty difficult. Rather, many features such as trajectory and pedestrian location are calculated first as intermediate steps.&lt;/p&gt;

&lt;p&gt;The main take-away from this section is that we should always be cautious of end-to-end approaches in applications where huge data is hard to come by.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc3&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;bias-variance-tradeoff&quot;&gt;Bias-Variance Tradeoff&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;Splitting your data.&lt;/strong&gt; In most deep learning problems, train and test come from different distributions. For example, suppose you are working on implementing an AI powered rearview mirror and have gathered 2 chunks of data: the first, larger chunk comes from many places (could be partly bought, and partly crowdsourced) and the second, much smaller chunk is actual car data.&lt;/p&gt;

&lt;p&gt;In this case, splitting the data into train/dev/test can be tricky. One might be tempted to carve the dev set out of the training chunk like in the first example of the diagram below. (Note that the chunk on the left corresponds to data mined from the first distribution and the one on the right to the one from the second distribution.)&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/split.svg&quot; width=&quot;500&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;This is bad because we usually want our dev and test sets to come from the same distribution. The reason is that part of the team will be spending a lot of time tuning the model to work well on the dev set; if the test set were to turn out very different from the dev set, then pretty much all of that work would have been wasted effort.&lt;/p&gt;

&lt;p&gt;Hence, a smarter way of splitting the above dataset would be just like the second line of the diagram. Now in practice, Andrew recommends creating dev sets from both data distributions: a train-dev and test-dev set. In this manner, any gap between the different errors can help you tackle the problem more clearly.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/errors.svg&quot; width=&quot;450&quot; /&gt;
&lt;/p&gt;
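
&lt;p&gt;The split described above can be sketched roughly as follows. This is only an illustration: the helper name, the 10% train-dev fraction and the half/half division of the small chunk are assumptions, not values given in the talk.&lt;/p&gt;

```python
import numpy as np

def split_two_sources(n_large, n_small, train_dev_frac=0.1, seed=0):
    """Hypothetical helper: return index arrays for the four sets."""
    rng = np.random.default_rng(seed)
    large = rng.permutation(n_large)  # shuffled indices into the large, mixed-source chunk
    small = rng.permutation(n_small)  # shuffled indices into the small, in-car chunk

    n_train_dev = int(train_dev_frac * n_large)
    half = n_small // 2
    return {
        'train': large[n_train_dev:],      # first distribution
        'train_dev': large[:n_train_dev],  # first distribution, held out
        'test_dev': small[:half],          # second (target) distribution
        'test': small[half:],              # second (target) distribution
    }

splits = split_two_sources(n_large=1000, n_small=100)
print({name: len(idx) for name, idx in splits.items()})
```

&lt;p&gt;Train and train-dev share one distribution while test-dev and test share the other, so each gap between consecutive errors points at a single cause.&lt;/p&gt;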

&lt;p&gt;&lt;strong&gt;Flowchart for working with a model.&lt;/strong&gt; Given what we have described above, here’s a simplified flowchart of the actions you should take when confronted with training/tuning a DL model.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/flowachart.svg&quot; width=&quot;500&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;The importance of data synthesis.&lt;/strong&gt; Andrew also stressed the importance of data synthesis as part of any workflow in deep learning. While it may be painful to manually engineer training examples, the relative gain in performance you obtain once the parameters and the model fit well is huge and worth your while.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc4&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;human-level-performance&quot;&gt;Human-level Performance&lt;/h3&gt;

&lt;p&gt;One of the very important concepts underlined in this lecture was that of human-level performance. In the basic setting, DL models tend to plateau once they have reached or surpassed human-level accuracy. While it is important to note that human-level performance doesn’t necessarily coincide with the Bayes error rate, it can serve as a very reliable proxy which can be leveraged to determine your next move when training your model.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
 &lt;img src=&quot;/assets/app_dl/perf.png&quot; width=&quot;550&quot; /&gt;
&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Reasons for the plateau.&lt;/strong&gt; There could be a theoretical limit on the dataset which makes further improvement futile (e.g. a noisy subset of the data). Humans are also very good at these tasks, so trying to make progress beyond that suffers from diminishing returns.&lt;/p&gt;

&lt;p&gt;Here’s an example that can help illustrate the usefulness of human-level accuracy. Suppose you are working on an image recognition task and measure the following:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Train error&lt;/strong&gt;: 8%&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Dev Error&lt;/strong&gt;: 10%&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;If I were to tell you that human-level error for such a task is on the order of 1%, then this would be a blatant bias problem and you could subsequently try increasing the size of your model, training longer, etc. However, if I told you that human-level error was on the order of 7.5%, then this would be more of a variance problem and you’d focus your efforts on methods such as data synthesis or gathering data more similar to the test set.&lt;/p&gt;
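
&lt;p&gt;The reasoning in this example boils down to comparing two gaps. Here is a minimal sketch, where &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;diagnose&lt;/code&gt; is a hypothetical helper and all numbers are error percentages taken from the example above:&lt;/p&gt;

```python
def diagnose(human_error, train_error, dev_error):
    """Hypothetical helper: name the larger of the two error gaps."""
    gaps = {
        # Gap between training error and the human-level proxy for the best achievable error.
        'bias': train_error - human_error,
        # Gap between dev error and training error.
        'variance': dev_error - train_error,
    }
    return max(gaps, key=gaps.get)

print(diagnose(human_error=1.0, train_error=8.0, dev_error=10.0))
print(diagnose(human_error=7.5, train_error=8.0, dev_error=10.0))
```

&lt;p&gt;With a human-level figure of 1%, the 7-point gap to the training error dominates; with 7.5%, the 2-point gap between training and dev error is the larger one.&lt;/p&gt;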

&lt;p&gt;By the way, there’s always room for improvement. Even if you are close to human-level accuracy overall, there could be subsets of the data where you perform poorly and working on those can boost production performance greatly.&lt;/p&gt;

&lt;p&gt;Finally, one might ask what is a good way of defining human-level accuracy. For example, in the following image diagnosis setting, ignoring the cost of obtaining data, which of these error rates should one pick as the criterion for human-level accuracy?&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;typical human&lt;/strong&gt;: 5%&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;general doctor&lt;/strong&gt;: 1%&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;specialized doctor&lt;/strong&gt;: 0.8%&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;group of specialized doctors&lt;/strong&gt;: 0.5%&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The answer is always the best accuracy possible, which here is that of the group of specialized doctors. This is because, as we mentioned earlier, human-level performance is a proxy for the Bayes optimal error rate, so providing a more accurate upper bound on your performance can help you strategize your next move.&lt;/p&gt;

&lt;p&gt;&lt;a name=&quot;toc5&quot;&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3 id=&quot;personal-advice&quot;&gt;Personal Advice&lt;/h3&gt;

&lt;p&gt;Andrew ended the presentation with 2 ways one can improve one’s skills in the field of deep learning.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Practice, Practice, Practice&lt;/strong&gt;: compete in Kaggle competitions and read associated blog posts and forum discussions.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Do the Dirty Work&lt;/strong&gt;: read a lot of papers and try to replicate the results. Soon enough, you’ll get your own ideas and build your own models.&lt;/li&gt;
&lt;/ul&gt;
</description>
        <pubDate>Mon, 26 Sep 2016 00:00:00 +0000</pubDate>
        <link>http://kevinzakka.github.io/2016/09/26/applying-deep-learning/</link>
        <guid isPermaLink="true">http://kevinzakka.github.io/2016/09/26/applying-deep-learning/</guid>
        
        <category>deep learning</category>
        
        <category>bias</category>
        
        <category>variance</category>
        
        <category>advice</category>
        
        <category>end-to-end</category>
        
        <category>machine learning</category>
        
        
      </item>
    
  </channel>
</rss>
