GradientSpace: Unsupervised Data Clustering for Improved Instruction Tuning
By: Shrihari Sridharan, Deepak Ravikumar, Anand Raghunathan and more
Potential Business Impact:
Teaches AI to learn many skills better.
Instruction tuning is one of the key steps required for adapting large language models (LLMs) to a broad spectrum of downstream applications. However, this procedure is difficult because real-world datasets are rarely homogeneous; they consist of a mixture of diverse information, which causes gradient interference: conflicting gradients pull the model in opposing directions and degrade performance. A common strategy to mitigate this issue is to group data based on semantic or embedding similarity. However, this approach fails to capture how data influences model parameters during learning. While recent works have attempted to cluster gradients directly, they randomly project gradients into lower dimensions to manage memory, which leads to accuracy loss. Moreover, these methods rely on expert ensembles, which necessitate multiple inference passes and expensive on-the-fly gradient computations during inference. To address these limitations, we propose GradientSpace, a framework that clusters samples directly in full-dimensional gradient space. We introduce an online SVD-based algorithm that operates on LoRA gradients to identify latent skills without the infeasible cost of storing all sample gradients. Each cluster is used to train a specialized LoRA expert, along with a lightweight router trained to select the best expert during inference. We show that routing to a single, appropriate expert outperforms the expert ensembles used in prior work, while significantly reducing inference latency. Our experiments across mathematical reasoning, code generation, finance, and creative writing tasks demonstrate that GradientSpace leads to coherent expert specialization and consistent accuracy gains over state-of-the-art clustering methods and finetuning techniques.
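To make the ideas of gradient interference and streaming gradient clustering concrete, here is a minimal toy sketch in Python with NumPy. It is not the paper's algorithm, which the abstract does not specify. Synthetic vectors stand in for per-sample LoRA gradients, and the helper names `update_basis` and `two_means`, the two skill directions, the rank `k=2`, and the batch size of 20 are all assumptions made for illustration. The sketch measures interference as the cosine similarity between the mean gradients of two skills, keeps a rank-k SVD summary that is updated batch by batch so that only the current batch is needed for the update, and then clusters the samples in the recovered subspace.

```python
import numpy as np

def update_basis(S, Vt, batch, k):
    """Fold a batch of gradient rows into a rank-k summary (S, Vt).

    Generic truncated incremental SVD: stack the scaled previous basis
    on top of the new rows, re-factorize, and keep the top k components.
    """
    stacked = batch if S is None else np.vstack([S[:, None] * Vt, batch])
    _, S_new, Vt_new = np.linalg.svd(stacked, full_matrices=False)
    return S_new[:k], Vt_new[:k]

def two_means(points, iters=10):
    """Tiny 2-cluster k-means with a deterministic farthest-point init."""
    c0 = points[0]
    c1 = points[np.argmax(np.linalg.norm(points - c0, axis=1))]
    centers = np.stack([c0, c1])
    for _ in range(iters):
        dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        centers = np.stack([points[labels == j].mean(axis=0) for j in range(2)])
    return labels

# Synthetic stand-ins for per-sample gradients (assumption: two hidden skills
# whose dominant gradient directions partially oppose each other).
rng = np.random.default_rng(0)
dim, n_per_skill = 64, 100
dir_a = np.zeros(dim)
dir_a[0] = 1.0
dir_b = np.zeros(dim)
dir_b[0], dir_b[1] = -1.0, 1.0
dir_b /= np.linalg.norm(dir_b)
true = rng.permutation(np.repeat([0, 1], n_per_skill))
noise = 0.05 * rng.standard_normal((2 * n_per_skill, dim))
grads = np.where(true[:, None] == 0, dir_a, dir_b) + noise

# Gradient interference: a negative cosine between the mean gradients of the
# two skills means a step that helps one skill tends to hurt the other.
mean_a = grads[true == 0].mean(axis=0)
mean_b = grads[true == 1].mean(axis=0)
interference = float(mean_a @ mean_b / (np.linalg.norm(mean_a) * np.linalg.norm(mean_b)))

# Streaming pass: the basis update itself only ever sees one batch at a time.
S, Vt = None, None
for start in range(0, len(grads), 20):
    S, Vt = update_basis(S, Vt, grads[start:start + 20], k=2)

# Toy simplification: the gradients are kept in memory here, so projecting
# them is trivial; a real pipeline would recompute or stream them instead.
coords = grads @ Vt.T
labels = two_means(coords)

# Cluster ids are arbitrary, so score agreement up to a label swap.
agreement = max(np.mean(labels == true), np.mean(labels != true))
print(f"cosine between skill gradients: {interference:.2f}")
print(f"cluster/skill agreement: {agreement:.2f}")
```

In this toy setting each recovered cluster would then receive its own expert. A nearest-centroid rule over the projected coordinates would be one simple stand-in for the trained router the abstract mentions, though the abstract does not say how that router is built.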