Understanding deep learning requires re-thinking generalization

Understanding deep learning requires re-thinking generalization Zhang et al., ICLR’17

This paper has a wonderful combination of properties: the results are easy to understand, somewhat surprising, and then leave you pondering over what it all might mean for a long while afterwards!

The question the authors set out to answer was this:

What is it that distinguishes neural networks that generalize well from those that don’t? A satisfying answer to this question would not only help to make neural networks more interpretable, but it might also lead to more principled and reliable model architecture design.

By “generalize well,” the authors simply mean that a network which performs well on the training data also performs well on the (held-out) test data; the question is what causes that. (This is as opposed to transfer learning, which involves applying the trained network to a related but different problem.) If you think about that for a moment, the question pretty much boils down to “why do neural networks work as well as they do?” Generalisation is the difference between just memorising portions of the training data and parroting them back, and actually developing some meaningful intuition about the dataset that can be used to make predictions. So it would be somewhat troubling, would it not, if the answer to the question “why do neural networks work (generalize) as well as they do?” turned out to be “we don’t really know!”

The curious case of the random labels

Our story begins in a familiar place – the CIFAR 10 (50,000 training images split across 10 classes, 10,000 validation images) and the ILSVRC (ImageNet) 2012 (1,281,167 training, 50,000 validation images, 1000 classes) datasets and variations of the Inception network architecture.

Train the networks using the training data, and you won’t be surprised to hear that they can reach zero error on the training set. This is highly indicative of overfitting: memorising training examples rather than learning true predictive features. We can use techniques such as regularisation to combat overfitting, leading to networks that generalise better. More on that later.

Take the same training data, but this time randomly jumble the labels (i.e., such that there is no longer any genuine correspondence between the label and what’s in the image). Train the networks using these random labels and what do you get? Zero training error!

In [this] case, there is no longer any relationship between the instances and the class labels. As a result, learning is impossible. Intuition suggests that this impossibility should manifest itself clearly during training, e.g., by training not converging or slowing down substantially. To our surprise, several properties of the training process for multiple standard architectures are largely unaffected by this transformation of the labels.

As the authors succinctly put it, “Deep neural networks easily fit random labels.” Here are three key observations from this first experiment:

  1. The effective capacity of neural networks is sufficient for memorising the entire data set.
  2. Even optimisation on random labels remains easy. In fact, training time increases by only a small constant factor compared with training on the true labels.
  3. Randomising labels is solely a data transformation, leaving all other properties of the learning problem unchanged.

If you take the network trained on random labels, and then see how well it performs on the test data, it of course doesn’t do very well at all because it hasn’t truly learned anything about the dataset. A fancy way of saying this is that it has a high generalisation error. Put all this together and you realise that:

… by randomizing labels alone we can force the generalization error of a model to jump up considerably without changing the model, its size, hyperparameters, or the optimizer. We establish this fact for several different standard architectures trained on the CIFAR 10 and ImageNet classification benchmarks. (Emphasis mine).

Or in other words: the model, its size, hyperparameters, and the optimiser cannot explain the generalisation performance of state-of-the-art neural networks. This must be the case because the generalisation performance can vary significantly while they all remain unchanged.

The even more curious case of the random images

What happens if we don’t just mess with the labels, but also mess with the images themselves? In fact, what if we just replace the true images with random noise? In the figures this is labeled as the ‘Gaussian’ experiment, because a Gaussian distribution with mean and variance matching the original image dataset is used to generate random pixels for each image.

It turns out that the networks still train to zero training error, and they get there even faster than in the random-labels case! One hypothesis for why this happens is that the random-pixel images are more separated from one another than the images in the random-labels case, where images that originally belonged to the same class must now be learned as members of different classes because of the label swaps.

The team experiment with a spectrum of changes introducing different degrees and kinds of randomisation into the dataset:

  • true labels (original dataset without modification)
  • partially corrupted labels (mess with some of the labels)
  • random labels (mess with all of the labels)
  • shuffled pixels (choose a pixel permutation, and then apply it uniformly to all images)
  • random pixels (apply a different random permutation to each image independently)
  • Gaussian (just make stuff up for each image, as described previously)

All the way along the spectrum, the networks are still able to perfectly fit the training data.
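As a concrete reading of that spectrum, here is a minimal NumPy sketch of the corruptions. It is not the authors’ code: the function names, the (N, H, W, C) image layout, treating a “pixel” as one spatial position together with all its colour channels, and matching a single dataset-wide mean and standard deviation for the Gaussian variant are all assumptions made for illustration.

```python
import numpy as np

def corrupt_labels(labels, p, num_classes, rng):
    """Independently, with probability p, replace each label by a uniformly random class.
    p = 0 leaves labels untouched; p = 1 gives the fully random-label setting."""
    labels = labels.copy()
    mask = rng.random(len(labels)) < p
    labels[mask] = rng.integers(0, num_classes, size=int(mask.sum()))
    return labels

def shuffled_pixels(images, perm):
    """Apply one fixed permutation of pixel positions to every image (layout N, H, W, C).
    The same perm should be reused for the test set."""
    n, h, w, c = images.shape
    return images.reshape(n, h * w, c)[:, perm, :].reshape(images.shape)

def random_pixels(images, rng):
    """Apply an independent permutation of pixel positions to each image."""
    n, h, w, c = images.shape
    flat = images.reshape(n, h * w, c)
    out = np.stack([img[rng.permutation(h * w)] for img in flat])
    return out.reshape(images.shape)

def gaussian_images(images, rng):
    """Replace every image with Gaussian noise matching the dataset-wide mean and std (assumed)."""
    return rng.normal(images.mean(), images.std(), size=images.shape)

# Toy usage on a hypothetical batch of 8 RGB 4x4 "images" with 10 classes
rng = np.random.default_rng(0)
images = rng.random((8, 4, 4, 3))
labels = rng.integers(0, 10, size=8)
perm = rng.permutation(4 * 4)

variants = {
    "true labels": (images, labels),
    "partially corrupted labels": (images, corrupt_labels(labels, 0.5, 10, rng)),
    "random labels": (images, corrupt_labels(labels, 1.0, 10, rng)),
    "shuffled pixels": (shuffled_pixels(images, perm), labels),
    "random pixels": (random_pixels(images, rng), labels),
    "gaussian": (gaussian_images(images, rng), labels),
}
for name, (x, y) in variants.items():
    print(f"{name:28s} images {x.shape}, first labels {y[:4]}")
```

The shuffled-pixels variant keeps a consistent (if scrambled) layout only because the same `perm` is applied to every image, train and test alike; random pixels draws a fresh permutation per image, so no spatial arrangement is shared between images.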

We furthermore vary the amount of randomization, interpolating smoothly between the case of no noise and complete noise. This leads to a range of intermediate learning problems where there remains some level of signal in the labels. We observe a steady deterioration of the generalization error as we increase the noise level. This shows that neural networks are able to capture the remaining signal in the data, while at the same time fit the noisy part using brute-force.
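A toy way to see “capture the remaining signal, fit the noise by brute force” is to swap the network for a 1-nearest-neighbour classifier, which simply memorises the training set. That substitution, the synthetic clustered data, and the noise levels are assumptions for illustration, not the paper’s setup. Training accuracy should stay at 1 for every noise level p, while test accuracy should fall roughly along (1 − p) + p/k: a test point’s nearest neighbour keeps its clean label with probability 1 − p, and a random label still lands on the right class with probability 1/k.

```python
import numpy as np

rng = np.random.default_rng(1)
k, d = 10, 20                               # 10 classes as in CIFAR 10; d is arbitrary
centres = 10.0 * rng.normal(size=(k, d))    # hypothetical, well-separated class centres

def sample(n):
    """Draw n points around the class centres; the true label is the class index."""
    y = rng.integers(0, k, size=n)
    return centres[y] + rng.normal(size=(n, d)), y

def one_nn(X_train, y_train, X_query):
    """1-nearest-neighbour prediction: a pure memoriser (zero training error on distinct points)."""
    # ||q - t||^2 = ||q||^2 - 2 q.t + ||t||^2, avoiding a large 3-D intermediate array
    d2 = ((X_query ** 2).sum(axis=1)[:, None]
          - 2.0 * X_query @ X_train.T
          + (X_train ** 2).sum(axis=1)[None, :])
    return y_train[d2.argmin(axis=1)]

X_train, y_train = sample(1000)
X_test, y_test = sample(2000)

sweep = []
for p in [0.0, 0.25, 0.5, 0.75, 1.0]:
    # Corrupt each training label independently with probability p (uniform random class)
    flip = rng.random(len(y_train)) < p
    noisy = np.where(flip, rng.integers(0, k, size=len(y_train)), y_train)
    train_acc = float((one_nn(X_train, noisy, X_train) == noisy).mean())
    test_acc = float((one_nn(X_train, noisy, X_test) == y_test).mean())  # test labels stay clean
    estimate = (1 - p) + p / k   # clean neighbour label, or a lucky random one
    sweep.append((p, train_acc, test_acc, estimate))
    print(f"p={p:.2f}  train acc={train_acc:.3f}  test acc={test_acc:.3f}  rough estimate={estimate:.3f}")
```

The endpoints reproduce the random-labels story in miniature: the same classifier on the same inputs fits the training set perfectly at p = 0 and at p = 1, yet its test accuracy drops from near 1 to roughly chance.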

For me that last sentence is key. Certain choices we make in model architecture clearly do make a difference to a model’s ability to generalise (otherwise all architectures would generalise equally well). Even the best-generalising network in the world will still have to fall back on memorisation when there is no true signal left in the data, though. So maybe we need a way to tease apart the potential for generalisation that exists in the dataset from how efficiently a given model architecture captures that latent potential. A simple way of doing that is to train different architectures on the same dataset! (Which we do all the time, of course.) That still doesn’t help us with the original quest, though: understanding why some models generalise better than others.

Regularization to the rescue?

The model architecture itself is clearly not a sufficient regulariser (can’t prevent overfitting / memorising). But what about commonly used regularisation techniques?

We show that explicit forms of regularization, such as weight decay, dropout, and data augmentation, do not adequately explain the generalization error of neural networks: Explicit regularization may improve generalization performance, but is neither necessary nor by itself sufficient for controlling generalization error.

Explicit regularisation seems to be more of a tuning parameter that helps improve generalisation, but its absence does not necessarily imply poor generalisation. Not all models that fit the training data generalise well, though. An interesting piece of analysis in the paper shows that we pick up a certain amount of regularisation just through the process of using gradient descent:

We analyze how SGD acts as an implicit regularizer. For linear models, SGD always converges to a solution with small norm. Hence, the algorithm itself is implicitly regularizing the solution… Though this doesn’t explain why certain architectures generalize better than other architectures, it does suggest that more investigation is needed to understand exactly what the properties are that are inherited by models trained using SGD.
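The linear-model argument can be checked numerically. Every SGD update on the squared loss is a multiple of some training example, so when started from w = 0 the iterate stays in the span of the data, and an exact fit lying in that span is the minimum-norm one. The sketch below assumes synthetic Gaussian data with fewer examples than parameters (n = 20, d = 100), a constant step size, and a fixed number of epochs; those choices are illustrative, not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 100                        # fewer examples than parameters: many exact fits exist
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

# Plain SGD on the per-example loss 0.5 * (x_i . w - y_i)^2, starting from w = 0
w = np.zeros(d)
lr = 1.0 / (X ** 2).sum(axis=1).max()  # step small enough for every example
for _ in range(1000):                  # epochs
    for i in rng.permutation(n):
        residual = X[i] @ w - y[i]
        w -= lr * residual * X[i]      # each update is a multiple of a data row

# Minimum-norm interpolating solution via the Moore-Penrose pseudoinverse
w_min_norm = np.linalg.pinv(X) @ y

# Another exact fit: add a component from the null space of X
null_vec = rng.normal(size=d)
null_vec -= np.linalg.pinv(X) @ (X @ null_vec)   # remove the row-space component
w_other = w_min_norm + null_vec

print("SGD fits the training data:", np.allclose(X @ w, y))
print("SGD solution equals the min-norm solution:", np.allclose(w, w_min_norm, atol=1e-6))
print(f"norms: SGD {np.linalg.norm(w):.3f}, other exact fit {np.linalg.norm(w_other):.3f}")
```

Both `w` and `w_other` fit the training data exactly, but SGD never acquires a null-space component, so it lands on the smaller-norm solution without any explicit penalty term.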

The effective capacity of machine learning models

Consider the case of neural networks working with a finite sample of size n. If a network has p parameters, where p is greater than n, then even a simple two-layer neural network can represent any function of the input sample. The authors prove (in an appendix) the following theorem:

There exists a two-layer neural network with ReLU activations and 2n + d weights that can represent any function on a sample of size n in d dimensions.

Even depth-2 networks of linear size can already represent any labeling of the training data!
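The construction behind the theorem is short enough to run. As I read the paper’s appendix, the network computes c(z) = Σ_j w_j max(⟨a, z⟩ − b_j, 0): project every sample onto a direction a, place the n biases b_j so that they interleave with the sorted projections, and solve a linear system for the output weights w. Once rows are sorted, that n × n system is lower triangular with a positive diagonal, so it is invertible. The random choice of a, the specific bias placement, and the function names below are assumptions for this sketch.

```python
import numpy as np

def fit_two_layer_relu(Z, y, rng):
    """Choose (a, b, w) so that c(z) = sum_j w_j * max(<a, z> - b_j, 0) equals y_i at every sample z_i.

    Z has shape (n, d). Total weight count: d (in a) + n (in b) + n (in w) = 2n + d.
    """
    n, d = Z.shape
    a = rng.normal(size=d)            # random direction: projections are distinct almost surely
    x = Z @ a                         # 1-D projection of each sample
    xs = np.sort(x)
    if not np.all(np.diff(xs) > 0):
        raise ValueError("projections must be distinct; try another direction a")
    # Interleave biases with sorted projections: b_1 < xs_1 < b_2 < xs_2 < ... < b_n < xs_n
    b = np.empty(n)
    b[0] = xs[0] - 1.0
    b[1:] = (xs[:-1] + xs[1:]) / 2.0
    # A[i, j] = max(x_i - b_j, 0): lower triangular with a positive diagonal after sorting rows
    A = np.maximum(x[:, None] - b[None, :], 0.0)
    w = np.linalg.solve(A, y)
    return a, b, w

def predict(a, b, w, Z):
    """Evaluate the depth-2 ReLU network on the rows of Z."""
    return np.maximum((Z @ a)[:, None] - b[None, :], 0.0) @ w

rng = np.random.default_rng(0)
n, d = 50, 10
Z = rng.normal(size=(n, d))
y = rng.integers(0, 10, size=n).astype(float)   # an arbitrary (random) labelling of the n points
a, b, w = fit_two_layer_relu(Z, y, rng)
pred = predict(a, b, w, Z)
print("number of weights:", a.size + b.size + w.size)   # 2n + d = 110
print("largest fitting error:", np.abs(pred - y).max())
```

Any labelling works here, including a random one, which is exactly the point: capacity alone is enough to memorise the training sample.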

So where does this all leave us?

This situation poses a conceptual challenge to statistical learning theory as traditional measures of model complexity struggle to explain the generalization ability of large artificial neural networks. We argue that we have yet to discover a precise formal measure under which these enormous models are simple. Another insight resulting from our experiments is that optimization continues to be empirically easy even if the resulting model does not generalize. This shows that the reasons for why optimization is empirically easy must be different from the true cause of generalization.

34 thoughts on “Understanding deep learning requires re-thinking generalization”

  1. Can you explain in a bit more detail what you mean by:

    Or in other words: the model, its size, hyperparameters, and the optimiser cannot explain the generalisation performance of state-of-the-art neural networks. This must be the case because the generalisation performance can vary significantly while they all remain unchanged.

    For me it seems obvious that there must exist a task (distribution) that would show that deep learning can be made arbitrarily bad on the test set. It seems trivially true just from the no free lunch theorem.

    How is it, then, that the experiment shows that they cannot explain the generalization properties of deep nets? If anything, it would seem that they do, since we showed there is some distribution that they don’t work on (in fact, there shouldn’t be any algorithm that can generalize in the random labels experiment). Or what did you have in mind when you said that?

    1. Hi Brando, this paragraph is directly related to the quote from the paper immediately above it; I’m trying to give an interpretation that makes it a little clearer what the authors are saying. There is no function f :: (Model x Size x Hyperparameters x Optimiser) -> GeneralisationAbility, since we can have exactly the same inputs to that function and get different generalisation performance. That pull quote is missing a little context. Here’s some additional background lifted from the abstract: “Conventional wisdom attributes small generalization error either to properties of the model family, or to the regularization techniques used during training. Through extensive systematic experiments, we show how these traditional approaches fail to explain why large neural networks generalize well in practice.” And this from section 1.1: “In this work, we problematize the traditional view of generalization by showing that it is incapable of distinguishing between different neural networks that have radically different generalization performance.”

      Now, where I’m still not sure about all this myself is this: if we take a dataset that does have some genuine signal in it (i.e., generalisation is possible if a model can learn that signal successfully), and hold that as an additional fixed input to my hypothetical function, then it seems to me we *can* now write such a function, and changes in the model etc. do clearly impact that. (I ramble about this a little later in the write-up.) Or, as you put it quite well, no models can be distinguished on generalisation performance when generalisation isn’t possible! But that doesn’t tell us that they can’t be when generalisation *is* possible. I’m still working through in my own mind exactly what implications we can draw. Certainly the paper shows the tremendous memorisation capabilities, and demonstrates that regularisation doesn’t seem to contribute as much to generalisation ability as perhaps was once thought.

      Sorry for the long rambling reply!
      Regards, A.

      1. If I understand correctly, the experiments don’t rule out the possibility, that what explains good generalization is the structure of the data. This actually was my intuition for a long time – that to explain (mathematically) a lot of the generalization performance we will need good mathematical description of the data itself – but then we could probably also build models better suited to data …

  2. This will all boil down to what does generalization even mean? I wrote about this several months ago (see: https://medium.com/intuitionmachine/rethinking-generalization-in-deep-learning-ec66ed684ace ) and more recently this ( https://medium.com/intuitionmachine/deep-learning-knowable-knowns-and-unknowns-17efb8822059 ).

    So generalization is about knowable unknowns. Accuracy in prediction therefore will depend on whatever knowledge you have currently. You can ‘rig’ that by selecting a neural network that enforces certain kinds of invariances, or you can try to rig it with a regularization that demands something like smoothness. However, the more kinds of priors you can cram into the Deep Learning network (via neural embeddings, etc.) or alternative perspectives (i.e., mixture of experts), the better your generalization.

    Ultimately though, it’s just going to be some approximation (or guess) that the network will perform based on what knowledge it has.

    Why do these networks work so well? Nobody has a good idea yet.

  3. > It is certainly not the case that not all models that fit the training data generalise well though.

    Thus there does not exist a model that, if it fits the training data, doesn’t generalize well? This goes against my understanding – perhaps one of the “not”s (or my parser) is out of alignment.

    1. That is a terribly clumsy (and wrong!) sentence on my part, sorry. It should just say “Not all models that fit the training data generalise well.” Regards, A.

  4. Hi adriancolyer. I’m Wenfei from China and am self-studying deep learning. Excellently written; a nice read, and I learned a lot. May I translate it into Chinese and share it with my deep learning mates? The translation would only be circulated in our group. Many thanks!

  5. This paper was well thought out, and touched on great points, particularly the images that are part of this blog that show the learning curves and the training steps. Thank you for taking the time to blog about this.

    1. I use WordPress.com hosting, and I’m not sure if they include any out-of-the-box spam plugins by default, but so far it’s been tolerable (but on the rise still). I do have to moderate every single comment because of this though.

    1. I noticed the excitement around that work, but I haven’t yet had a chance to study the paper – but it’s on my backlog! If I can make enough sense of it, I’ll do a write-up on the blog :)

  6. Dear all, as an amateur from another field, I ask: what happens if one trains a DNN on, say, 60k different fixed images with good labels, but each training iteration is corrupted by a different instance of, say, 30% noise (ranging from 10% to 90% in different experiments)? For the label-swapping experiment, maybe the labels are 90% swapped in each epoch and 10% invariant. How does the performance degrade then? Any refs? Thanks, -sg
