3 NIPS Papers We Loved

Jeremy Stanley
Published in tech-at-instacart · 12 min read · Dec 14, 2017


Know your model’s limits, interpret its behavior and learn from variable length sets.

One of two “breakout sessions” with presenter and GIANT screen for scale.

At NIPS 2017 what surprised me the most was not the size of the crowds (they were huge), the extravagance of the parties (I sleep early) or the controversy of the “rigor police” debate (it was entertaining).

No, what surprised me the most was the number of papers I saw that (when combined with talks and posters) were both relatively easy to understand and of immediate practical use.

In this post, I will briefly explain three of our favorites:

  1. Knowing your model’s limits
    Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles
    Lakshminarayanan et al. 2017, paper & video (1:00:10)
  2. Interpreting model behavior
    A Unified Approach to Interpreting Model Predictions
    Lundberg et al. 2017, paper, video (17:45) & github
  3. Learning from variable length sets
    Deep Sets
    Zaheer et al. 2017, paper & video (16:00)

I’d like to extend a huge thank-you to Balaji Lakshminarayanan, Scott Lundberg, Manzil Zaheer and their co-authors for doing this work and presenting at NIPS. Their cogent presentations and detailed answers to my many questions at their poster sessions enabled and inspired this post.

Knowing your model’s limits

Lakshminarayanan et al. 2017, paper & video (1:00:10)

Deep learning models can be surprisingly brittle. They can fail to generalize on data drawn from slightly different distributions and can give very different predictions given minor changes in the learning algorithm or initialization.

This raises the question: can we know when our deep learning models are uncertain about their predictions?

If so, this would help in many applications at Instacart, such as:

  • How uncertain are we about an item being in stock at a store location?
  • How much risk is there in a grocery delivery being late?
  • Is there a chance we should explore showing a rare item for a search?
  • What range of delivery demand should we anticipate at a store location?

In particular, anytime you make a decision based upon many noisy predictions, you risk favoring observations with large noise values (common in ranking for search or ads, or in optimization for pricing or logistics applications). Controlling for prediction uncertainty is important to avoid this effect.

Other methods can be used to quantify uncertainty, but they have drawbacks. For example, Bayesian methods require assumptions about priors and are computationally expensive.

This paper provides an elegant method to quantify the uncertainty in deep learning models:

Lakshminarayanan et al. 2017 (video)

In practice you:

  • Choose a distribution for your output (Gaussian if you are optimizing for MSE, Poisson for counts, etc.)
  • Change the final layer in your deep network to output a variance estimate (or other distribution parameters) in addition to an estimate for the mean
  • Minimize the negative log-likelihood for the output distribution (e.g., with a custom loss function in Keras)
  • Train M networks in this way, each with a different random initialization
  • Let your final predicted distribution be the evenly weighted mixture of distributions from the M networks

While the paper also adds adversarial training (hard to implement for discrete inputs), some of their experiments showed that this was less important.

What is critical is that your network must produce an estimate of both the mean and the variance, and then optimize the negative log-likelihood loss function. If you assume your errors are Gaussian distributed, then your loss function is:

Lakshminarayanan et al. 2017 (paper)
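Written out, the Gaussian negative log-likelihood used here takes the form (up to an additive constant):

$$-\log p_\theta(y \mid \boldsymbol{x}) \;=\; \frac{\log \sigma^2_\theta(\boldsymbol{x})}{2} \;+\; \frac{\bigl(y - \mu_\theta(\boldsymbol{x})\bigr)^2}{2\,\sigma^2_\theta(\boldsymbol{x})} \;+\; \text{constant}$$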

Where 𝜇 is the network’s estimate of the mean (conditioned on weights θ and input 𝒙), and σ² is the network’s estimate of the variance. If you assume a constant σ, this reduces to classical regression with MSE.

For an example of implementing a similar loss function in Keras, see the WTTE package, which uses a Weibull distribution instead of a Gaussian.
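As a rough illustration of the recipe above (a minimal sketch, not the paper’s reference implementation), here is what a two-output regression head with a Gaussian negative log-likelihood loss and a small ensemble might look like in tf.keras; the layer sizes, optimizer and ensemble size are placeholders:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def gaussian_nll(y_true, y_pred):
    # y_pred packs [mean, log_variance]; predicting log-variance keeps sigma^2 > 0.
    mean, log_var = y_pred[:, 0:1], y_pred[:, 1:2]
    return tf.reduce_mean(0.5 * log_var + 0.5 * tf.square(y_true - mean) * tf.exp(-log_var))

def make_model(n_inputs):
    inputs = keras.Input(shape=(n_inputs,))
    h = layers.Dense(64, activation="relu")(inputs)
    h = layers.Dense(64, activation="relu")(h)
    outputs = layers.Dense(2)(h)  # [mean, log_variance]
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss=gaussian_nll)
    return model

# Train M networks, each with a different random initialization:
M = 5
ensemble = [make_model(n_inputs=10) for _ in range(M)]
# for model in ensemble:
#     model.fit(X_train, y_train, epochs=50, verbose=0)

def ensemble_predict(models, X):
    """Evenly weighted Gaussian mixture of the M networks' predictions."""
    preds = np.stack([m.predict(X) for m in models])      # shape (M, n, 2)
    means, variances = preds[..., 0], np.exp(preds[..., 1])
    mixture_mean = means.mean(axis=0)
    mixture_var = (variances + means ** 2).mean(axis=0) - mixture_mean ** 2
    return mixture_mean, mixture_var
```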

The following toy example from their paper illustrates the impact, where each red point is drawn from y = x³ + ε where ε ∼ N(0, 3²), the blue line is y = x³ and the grey range is the method’s variance estimate conditioned on x:

The leftmost plot shows the variance of training M=5 simple networks which only output the mean and were optimized for MSE. Each model produces only a point estimate, and there is little variance observed over the ensemble.

The second plot shows the results of following the above recipe but with M=1. In this case, the network produces a distribution, but its level of uncertainty remains constant even when generalizing outside of its domain.

The third plot includes adversarial training (note how little difference it makes) with M=1, and the final plot does everything (mean and variance outputs, adversarial training and M=5). Only the final plot does a reasonable job of estimating uncertainty outside of the range of the training data.

The authors then show that an ensemble of networks trained in this way on digit classification with MNIST data does a far better job of estimating its uncertainty than other techniques like Monte Carlo dropout:

Lakshminarayanan et al. 2017 (paper)

In the above visualization, they vary the number of networks in the ensemble, and compare Monte Carlo dropout (green) to a simple ensemble (red) to an ensemble with adversarial training (blue). The grey curves use random data augmentation (rather than adversarial examples), and show that it is the adversarial approach that adds incremental value over a simple ensemble.

Finally, and perhaps most impressive of all, the authors show that their method responds appropriately when presented with data from an entirely different domain (letters rather than numbers):

Lakshminarayanan et al. 2017 (video)

The top blue plots show the uncertainty (measured as entropy, since this is a classification problem) for digit classification when presented with numbers. The bottom red plots show the uncertainty when presented with letters.

When using just 1 network in the ensemble (how most deep learning models are deployed), the model trained only on numbers gives equally confident (but obviously wrong) classification results for letters! But increasing to even just 5 networks produces significantly less confident predictions.
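As a concrete sketch of how you might measure this in practice (assuming `models` is a list of trained classifiers whose `predict` returns softmax class probabilities):

```python
import numpy as np

def ensemble_entropy(models, X, eps=1e-12):
    """Entropy of the ensemble-averaged class probabilities.

    For a digit classifier fed letters, this should be high if the ensemble is
    appropriately uncertain; for in-distribution digits it should be low.
    """
    probs = np.mean([m.predict(X) for m in models], axis=0)  # (n_examples, n_classes)
    return -np.sum(probs * np.log(probs + eps), axis=1)
```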

Interpreting model behavior

Lundberg et al. 2017, paper, video (17:45) & github

Most complex machine learning models are black boxes — we simply cannot fully understand how they work. However, we can gain deeper insight locally into the predictions that they make, and through this insight can better understand our data and models.

This understanding can be used to:

  • Build intuition for how our algorithms behave
  • Alter end user experiences to provide more context for predictions
  • Debug model building issues arising from data quality, model fit or generalization ability
  • Measure the value of different features in a model, and inform decisions for future data collection and engineering

At Instacart, we often want to deeply understand models we build such as:

  • The expected time until a user places their next order, as a function of their past order, delivery, site and rating behavior
  • What product pairs are good replacements for each other in case we cannot find what the customer originally requested
  • How our customers react to limited delivery availability options or busy pricing

The SHAP (SHapley Additive exPlanations) paper and package provides an elegant way to decompose a model’s predictions into additive effects, which can then be easily visualized.

For example, here is a visualization that explains a LightGBM prediction of the chance a household earns $50k or more from a UCI census dataset:

In this case, the predicted log-odds of high income is -1.94; the largest factor depressing this chance is young age (blue), and the largest factor increasing it is marital status (red).
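A rough sketch of how such an explanation can be produced with the shap package, assuming a pandas DataFrame X of census features and a 0/1 income label y are already loaded (the model settings here are placeholders):

```python
import lightgbm as lgb
import shap

# Any tree ensemble works; LightGBM is used here to mirror the example above.
model = lgb.LGBMClassifier(n_estimators=200).fit(X, y)

explainer = shap.TreeExplainer(model)
# One additive contribution per feature per row (in log-odds for this model).
# Note: for some binary classifiers / shap versions this is a per-class list,
# in which case take the entry for the positive class.
shap_values = explainer.shap_values(X)

# Explain a single prediction: base rate plus per-feature pushes up or down.
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0, :], X.iloc[0, :])
```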

Furthermore, you can visualize the aggregate impact of features on model predictions over an entire dataset with visualizations like these:

Lundberg et al. 2017 (github)

Here they find that Age is the most predictive feature, largely because there is a separable group (the young) with low income. Capital Gain is the next most predictive, in part because of both very high and very low contributions.

This is a huge improvement over the typical information-gain-based variable importance visualizations commonly used with packages like XGBoost and LightGBM, which only show the relative importance of each feature:

R XGBoost Vignette

The package can also provide rich partial dependence plots which show the range of impact that a feature has across the training dataset population:

Lundberg et al. 2017 (github)

Note that the vertical spread of values in the above plot represents interaction effects between Age and other variables (the effect of Age changes with other variables). This is in contrast to traditional partial dependence plots, which show only the effect of varying Age in isolation.
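Continuing the sketch above, both aggregate views come from the same SHAP values (assuming the "Age" column name matches your dataset):

```python
# Aggregate impact of every feature across the whole dataset.
shap.summary_plot(shap_values, X)

# SHAP value of Age plotted against Age itself; the vertical spread at a given
# age reflects interactions with the other features.
shap.dependence_plot("Age", shap_values, X)
```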

To understand how the SHAP algorithm works, consider this example for a single observation:

Lundberg et al. 2017 (video)

Their model is predicting the chance of high income, and on average predicts a base rate of 20% for the entire population, denoted by E[f(x)]. For this specific example (named John in the talk), they predict a 55% probability, denoted by f(x).

The SHAP values answer the question of how they got from 20% to 55% for John.

Lundberg et al. 2017 (video)

They begin by ordering the features randomly, perhaps starting with Age, and ask how much the average prediction of 20% changes for users whose age is the same as John’s, denoted E[f(x) | x₁]. This can be found by integrating f(x) over all other features besides x₁ in the training dataset (a process that can be done efficiently in trees).

Suppose that they find that the prediction goes up to 35%, and so this gives them an estimate for the effect of Age, ϕ₁=15%. They then iteratively repeat this process through the remaining variables (concluding with marital status), to estimate ϕ₂, ϕ₃ and ϕ₄ for each of the other three features in this example:

Lundberg et al. 2017 (video)

However, unless a model is purely additive, the estimates for ϕ will vary with the ordering of features chosen. The SHAP algorithm solves this by averaging over all possible orderings (equivalently, a weighted sum over all 2ᴺ feature subsets). The computational burden of considering every ordering is alleviated by sampling M of them and using a regression model to attribute the impact from the samples to each feature.
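To build intuition, here is a simplified Monte Carlo sketch of that idea (not the paper’s Kernel SHAP or Tree SHAP algorithms): sample random orderings, add the explained observation’s features one at a time, and credit each feature with the change in the prediction, filling not-yet-added features from a background dataset.

```python
import numpy as np

def sample_shap_values(f, x, background, n_samples=200, seed=None):
    """Monte Carlo estimate of Shapley values for a single observation.

    f          : callable mapping a 2-D feature array to predictions
    x          : 1-D array, the observation to explain
    background : 2-D array of "typical" rows used to fill in features that
                 have not yet been added to the ordering
    """
    rng = np.random.default_rng(seed)
    n_features = x.shape[0]
    phi = np.zeros(n_features)

    for _ in range(n_samples):
        order = rng.permutation(n_features)
        # Start from a random background row (all features "unknown").
        base = background[rng.integers(len(background))].copy()
        prev = f(base[None, :])[0]
        # Add the explained observation's features one at a time in this
        # random order, crediting each with the change in the prediction.
        for j in order:
            base[j] = x[j]
            curr = f(base[None, :])[0]
            phi[j] += curr - prev
            prev = curr

    # Averaged over samples, phi sums (approximately) to f(x) - E[f(x)].
    return phi / n_samples
```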

The paper justifies the above approach using game theory, and further shows that this theory unifies other interpretation methodologies such as LIME and DeepLIFT:

Lundberg et al. 2017 (video)

And finally, because no NIPS paper would be complete without an MNIST example, they show that the SHAP algorithm does a better job at explaining what part of an 8 represents the essence of an 8 (as opposed to a 3):

Lundberg et al. 2017 (paper)

This shows that their approach can work well even for deep learning models.

Learning from variable length sets

Zaheer et al. 2017, paper & video (16:00)

Established deep learning architectures exist for modeling sparse categorical data (embeddings), sequence data (LSTMs) and image data (CNNs). But what do you do if you want your model to depend upon a variable length unordered set of inputs?

This was precisely the question we asked ourselves at Instacart a year ago while pondering our work on sorting grocery shopping lists in our Deep Learning with Emojis (Not Math) post.

I was overjoyed (and humbled) to see this paper, which generalizes our work, at the NIPS poster session Wednesday night; it immediately reminded me of this tweet by Rachel Thomas:

Tweet by Rachel Thomas

In the Deep Sets paper, the authors explain that set-based modeling problems fall into two classes:

Zaheer et al. 2017 (video)

In the permutation invariant case, you want to be able to re-order the inputs to your model without affecting the prediction (which often lives in a space of a different dimension from your input).

For example, at Instacart we could predict:

  • How much time it will take to pick a basket of groceries at a store location
  • Whether a user will add any item to their cart, given a query and a set of product search results
  • How efficient will we be in a city given a set of deliveries and their location and due times, and a set of shoppers and their locations and current status

In the permutation equivariant case, you will produce a predicted value for every input in the set, and you want to be able to re-order the inputs and ensure that the ordering of the outputs changes accordingly.

At Instacart, there are analogous examples where we would want a prediction for every element of an input set.

The paper proves that any such set-based architecture must take the following form:

Zaheer et al. 2017 (video)

For the permutation invariant case, the architecture will look like this:

Zaheer et al. 2017 (video)

Where ϕ is an arbitrary neural network architecture applied independently to every set element 𝒙 (for example, using the Keras TimeDistributed layer wrapper). The outputs of ϕ must then be summed along the set dimension, and can then be passed into yet another arbitrary neural network ρ, which produces the final output predictions.
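A minimal sketch of that permutation-invariant pattern in tf.keras, assuming fixed-size (padded) sets for simplicity; the layer widths are placeholders, and a real model would also mask padded elements:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def make_invariant_model(set_size, n_features):
    x = keras.Input(shape=(set_size, n_features))
    # phi: applied independently to every element of the set
    h = layers.TimeDistributed(layers.Dense(64, activation="relu"))(x)
    h = layers.TimeDistributed(layers.Dense(64, activation="relu"))(h)
    # Sum over the set dimension -- after this, element order no longer matters.
    pooled = layers.Lambda(lambda t: tf.reduce_sum(t, axis=1))(h)
    # rho: an arbitrary network on the pooled representation
    out = layers.Dense(64, activation="relu")(pooled)
    out = layers.Dense(1)(out)
    return keras.Model(x, out)
```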

For the permutation equivariant case, the architecture is similar, but the per-element layers are replaced with permutation-equivariant DeepSets layers:

Zaheer et al. 2017 (video)

Where you can see that the outputs re-order along with the inputs (equivariance), thanks to the symmetry in the weight sharing.
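A hedged sketch of one such layer, following the DeepSets weight-sharing pattern of an element-wise transform plus a transform of a pooled summary shared by all elements (the paper also discusses a max-pooling variant):

```python
import tensorflow as tf
from tensorflow.keras import layers

class EquivariantDense(layers.Layer):
    """y_i = activation(W_elem x_i + W_pool pool({x_j}) + b).

    The pooled term is identical for every element, so permuting the inputs
    permutes the outputs in exactly the same way (permutation equivariance).
    """
    def __init__(self, units, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.elem = layers.Dense(units)                    # per-element transform
        self.pool = layers.Dense(units, use_bias=False)    # transform of pooled summary
        self.activation = tf.keras.activations.get(activation)

    def call(self, x):                                     # x: (batch, set_size, d)
        pooled = tf.reduce_mean(x, axis=1, keepdims=True)  # (batch, 1, d)
        return self.activation(self.elem(x) + self.pool(pooled))
```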

The paper provides an obligatory MNIST example, where they seek to learn an architecture that can sum hand-written digits:

Zaheer et al. 2017 (video)

In this case you want the architecture to be permutation invariant, so that sum(1, 2) = sum(2, 1), and to handle variable length input such as sum(1, 2, 7).

Two simple alternative approaches both fail:

On the left hand side, they concatenate the digits and pass them into a hidden layer, but this fails to handle variable sequence length inputs. On the right hand side, they pass them into a recurrent layer, but the results will not be order invariant.

How big of a deal is that? In practice, they found that both GRU and LSTM layers failed dramatically to generalize to sequence lengths longer than 10:

Zaheer et al. 2017 (video)

This paper is particularly rich with application examples, ranging from image tagging, to outlier detection, to point-cloud classification:

Zaheer et al. 2017 (paper)

Summary

Beyond all the hype, NIPS 2017 was an amazing event, and these three papers demonstrate how practically useful these conferences are for applied AI and Machine Learning work. In each case, the authors’ work provided mathematical rigor, practical advice, and experimental validation for questions we have been pondering at Instacart.

I hope that you are now as excited by these ideas as we are! If you are interested in working on one of the many challenging problems we have at Instacart, check out our careers page at careers.instacart.com.

Again, I’d like to thank Balaji Lakshminarayanan, Scott Lundberg, Manzil Zaheer and their co-authors for their work, and to everyone involved in organizing NIPS 2017. I’d also like to thank Jeremy Howard for his feedback on this post.
