NTKMTL: Mitigating Task Imbalance in Multi-Task Learning from Neural Tangent Kernel Perspective
By: Xiaohan Qin, Xiaoxing Wang, Ning Liao, and more
Potential Business Impact:
Teaches a single AI model to learn many tasks at once without letting any task lag behind.
Multi-Task Learning (MTL) enables a single model to learn multiple tasks simultaneously, leveraging knowledge transfer among tasks for better generalization, and has been widely applied across various domains. However, task imbalance remains a major challenge in MTL. Although balancing the convergence speeds of different tasks is an effective way to address this issue, it is difficult to accurately characterize the training dynamics and convergence speeds of multiple tasks within a complex MTL system. To this end, we analyze the training dynamics of MTL through the lens of Neural Tangent Kernel (NTK) theory and propose a new MTL method, NTKMTL. Specifically, we introduce an extended NTK matrix for MTL and use spectral analysis to balance the convergence speeds of the tasks, thereby mitigating task imbalance. Building on an approximation via the shared representation, we further propose NTKMTL-SR, which improves training efficiency while maintaining competitive performance. Extensive experiments demonstrate that our methods achieve state-of-the-art performance across a wide range of benchmarks, including both multi-task supervised learning and multi-task reinforcement learning. Source code is available at https://github.com/jianke0604/NTKMTL.
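The abstract only describes the method at a high level; the sketch below is a minimal, hypothetical illustration of the general idea of weighting tasks by an NTK-style spectral quantity. It assumes a PyTorch model with shared parameters and one scalar loss per task, and it uses a task-level Gram matrix of per-task gradients as a rough stand-in for the paper's extended NTK matrix. The helper name `ntk_style_task_weights` and the specific weighting rule are illustrative assumptions, not the authors' actual algorithm.

```python
import torch

def ntk_style_task_weights(task_losses, shared_params, eps=1e-8):
    """Hypothetical sketch: derive task weights from a task-level Gram
    (NTK-like) matrix built from per-task gradients w.r.t. shared parameters.

    task_losses: list of scalar loss tensors, one per task.
    shared_params: iterable of parameters shared across tasks.
    Returns (weights, spectrum): per-task weights and the eigenvalues of
    the task-level Gram matrix.
    """
    params = [p for p in shared_params if p.requires_grad]
    per_task_grads = []
    for loss in task_losses:
        grads = torch.autograd.grad(loss, params, retain_graph=True,
                                    allow_unused=True)
        flat = torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, params)
        ])
        per_task_grads.append(flat)

    G = torch.stack(per_task_grads)       # (T, P): one gradient row per task
    K = G @ G.t()                         # (T, T): task-level Gram matrix
    spectrum = torch.linalg.eigvalsh(K)   # eigenvalues, for spectral analysis

    # Assumption: use each task's diagonal entry K[t, t] (its squared gradient
    # norm) as a crude proxy for convergence speed, and upweight slow tasks.
    speed = K.diag().clamp_min(eps)
    weights = 1.0 / speed
    weights = weights * len(task_losses) / weights.sum()
    return weights.detach(), spectrum.detach()
```

In training, one would then back-propagate the reweighted objective `sum(weights[t] * task_losses[t])` at each step; per the abstract, the NTKMTL-SR variant instead approximates the needed quantities via the shared representation to improve training efficiency.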
Similar Papers
Simplifying Multi-Task Architectures Through Task-Specific Normalization
Machine Learning (CS)
Makes AI learn many things better with less effort.
Tensorized Multi-Task Learning for Personalized Modeling of Heterogeneous Individuals with High-Dimensional Data
Machine Learning (CS)
Helps computers learn about different groups of people.
Injecting Imbalance Sensitivity for Multi-Task Learning
Machine Learning (CS)
Improves AI learning by balancing tasks.