High-entropy Advantage in Neural Networks' Generalizability
By: Entao Yang, Xiaotian Zhang, Yue Shang, and more
Potential Business Impact:
Makes computers learn better by using "energy" ideas.
One of the central challenges in modern machine learning is understanding how neural networks generalize knowledge learned from training data to unseen test data. While numerous empirical techniques have been proposed to improve generalization, a theoretical understanding of the mechanism of generalization remains elusive. Here we introduce the concept of Boltzmann entropy into neural networks by re-conceptualizing such networks as hypothetical molecular systems, where weights and biases are atomic coordinates and the loss function is the potential energy. By employing molecular simulation algorithms, we compute entropy landscapes as functions of both training loss and test accuracy (or test loss), on networks with up to 1 million parameters, across four distinct machine learning tasks: arithmetic questions, real-world tabular data, image recognition, and language modeling. Our results reveal the existence of a high-entropy advantage, wherein high-entropy network states generally outperform those reached via conventional training techniques like stochastic gradient descent. This entropy advantage provides a thermodynamic explanation for neural network generalizability: generalizable states occupy a larger part of the parameter space than their non-generalizable analogs at low training loss. Furthermore, we find this advantage is more pronounced in narrower neural networks, indicating a need for training optimizers tailored to networks of different sizes.
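The mapping the abstract describes (parameters as atomic coordinates, loss as potential energy) can be illustrated with a toy sketch. The snippet below is not the authors' method; it is a minimal, hypothetical illustration using a single-weight linear "network" and a Metropolis Monte Carlo walk at fixed temperature, with a histogram of visited losses standing in for the entropy landscape. The model, temperature, and step size are all assumptions chosen for readability.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "network": one weight w fitting y = 2x on a few points.
# The weight plays the role of an atomic coordinate; the MSE training
# loss plays the role of the potential energy of that configuration.
X = np.linspace(-1.0, 1.0, 8)
y = 2.0 * X

def energy(w):
    """Training loss viewed as the potential energy of a configuration."""
    return np.mean((w * X - y) ** 2)

def metropolis_sample(n_steps=20000, temperature=0.1, step_size=0.2):
    """Metropolis Monte Carlo walk over parameter space at fixed T."""
    w = rng.normal()
    e = energy(w)
    visited = []
    for _ in range(n_steps):
        w_new = w + step_size * rng.normal()
        e_new = energy(w_new)
        # Accept moves with the Boltzmann probability exp(-dE / T).
        if e_new < e or rng.random() < np.exp(-(e_new - e) / temperature):
            w, e = w_new, e_new
        visited.append(e)
    return np.array(visited)

energies = metropolis_sample()
# A histogram over visited losses is a crude stand-in for the entropy
# landscape: S(loss) scales with log(density of sampled states).
hist, edges = np.histogram(energies, bins=20)
```

At low temperature the walk concentrates in low-loss regions, so most probability mass piles up in the lowest-loss histogram bins; the paper's actual entropy landscapes require flat-histogram molecular-simulation algorithms rather than this plain fixed-temperature sampler.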
Similar Papers
Neural Thermodynamics I: Entropic Forces in Deep and Universal Representation Learning
Machine Learning (CS)
Explains how AI learns by using "entropic forces."
Overfitting has a limitation: a model-independent generalization error bound based on Rényi entropy
Machine Learning (Stat)
Lets computers learn with more data, not bigger.
Emergence of Structure in Ensembles of Random Neural Networks
Machine Learning (CS)
Makes computers learn better from random guesses.