Gradient Dynamics of Attention: How Cross-Entropy Sculpts Bayesian Manifolds

Published: December 27, 2025 | arXiv ID: 2512.22473v1

By: Naman Aggarwal, Siddhartha R. Dalal, Vishal Misra

Transformers have been shown empirically to perform precise probabilistic reasoning, both in carefully constructed "Bayesian wind tunnels" and in large-scale language models. Yet the mechanisms by which gradient-based learning creates the required internal geometry remain opaque. We provide a complete first-order analysis of how cross-entropy training reshapes the attention scores and value vectors of a transformer attention head. Our core result is an advantage-based routing law for attention scores,

\[ \frac{\partial L}{\partial s_{ij}} = \alpha_{ij}\bigl(b_{ij}-\mathbb{E}_{\alpha_i}[b]\bigr), \qquad b_{ij} := u_i^\top v_j, \]

coupled with a responsibility-weighted update for values,

\[ \Delta v_j = -\eta\sum_i \alpha_{ij} u_i, \]

where $L$ is the cross-entropy loss, $s_{ij}$ are attention scores, $\alpha_{ij}$ are the corresponding softmax attention weights, $u_i$ is the upstream gradient of $L$ with respect to the head's output at position $i$, and $\eta$ is the learning rate. These equations induce a positive feedback loop in which routing and content specialize together. Under gradient descent, queries route more strongly to values that are better than average at reducing their error (that is, $b_{ij} < \mathbb{E}_{\alpha_i}[b]$), and those values are pulled, in proportion to attention, along the error-reducing directions $-u_i$ of the queries that use them. We show that this coupled specialization behaves like a two-timescale EM procedure: attention weights implement an E-step (soft responsibilities), while values implement an M-step (responsibility-weighted prototype updates), with queries and keys adjusting the hypothesis frame. Through controlled simulations, including a sticky Markov-chain task in which we compare a closed-form EM-style update with standard SGD, we demonstrate that the same gradient dynamics that minimize cross-entropy also sculpt the low-dimensional manifolds identified in our companion work as implementing Bayesian inference. This yields a unified picture in which optimization (gradient flow) gives rise to geometry (Bayesian manifolds), which in turn supports function (in-context probabilistic reasoning).
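Both update rules can be checked numerically. The sketch below is a minimal illustration, not the authors' code. It assumes a single attention head whose output $o_i = \sum_j \alpha_{ij} v_j$ is fed directly as logits into a softmax cross-entropy loss. This is a hypothetical simplification, since real models apply further projections. The sketch also treats the scores $s_{ij}$ as free parameters and compares the closed-form gradients against central finite differences. The helper names (`ce_loss`, `analytic_grads`, `numeric_grad`) are illustrative only.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def head_output(S, V):
    """o_i = sum_j alpha_ij v_j, with alpha_ij = softmax_j(s_ij)."""
    return softmax(S) @ V

def ce_loss(S, V, y):
    """Cross-entropy with the head output used directly as logits (assumption)."""
    O = head_output(S, V)
    mx = O.max(axis=1, keepdims=True)
    logp = O - (mx + np.log(np.exp(O - mx).sum(axis=1, keepdims=True)))
    return float(-logp[np.arange(len(y)), y].sum())

def analytic_grads(S, V, y):
    """Closed-form gradients from the routing law and the value update."""
    A = softmax(S)                                   # alpha_ij
    U = softmax(A @ V)                               # u_i = softmax(o_i) - onehot(y_i)
    U[np.arange(len(y)), y] -= 1.0
    B = U @ V.T                                      # b_ij = u_i^T v_j
    baseline = (A * B).sum(axis=1, keepdims=True)    # E_{alpha_i}[b]
    dS = A * (B - baseline)                          # dL/ds_ij = alpha_ij (b_ij - E[b])
    dV = A.T @ U                                     # dL/dv_j = sum_i alpha_ij u_i
    return dS, dV

def numeric_grad(f, X, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
n, m, d = 3, 4, 5                  # query positions, key/value slots, classes
S = rng.normal(size=(n, m))        # attention scores s_ij
V = rng.normal(size=(m, d))        # value vectors v_j
y = rng.integers(0, d, size=n)     # target class per query position

dS, dV = analytic_grads(S, V, y)
dS_num = numeric_grad(lambda X: ce_loss(X, V, y), S)
dV_num = numeric_grad(lambda X: ce_loss(S, X, y), V)
print("max |dS error|:", np.abs(dS - dS_num).max())
print("max |dV error|:", np.abs(dV - dV_num).max())

# One responsibility-weighted value step: Delta v_j = -eta * sum_i alpha_ij u_i
eta = 0.01
V_new = V - eta * dV
print("loss before/after:", ce_loss(S, V, y), ce_loss(S, V_new, y))
```

Both error printouts should be at finite-difference precision, and the small value step should lower the loss. Each row of `dS` also sums to zero. Because the weights $\alpha_{ij}$ sum to one over $j$, the routing law only redistributes attention mass among values relative to the baseline $\mathbb{E}_{\alpha_i}[b]$; it never changes the total.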

Category: Statistics: Machine Learning (Stat)