SAMO: A Lightweight Sharpness-Aware Approach for Multi-Task Optimization with Joint Global-Local Perturbation
By: Hao Ban, Gokul Ram Subramani, Kaiyi Ji
Potential Business Impact:
Helps AI models learn many tasks at once, better and faster.
Multi-task learning (MTL) enables a joint model to capture commonalities across multiple tasks, reducing computational cost and improving data efficiency. However, a major challenge in MTL optimization is task conflict, where the task gradients differ in direction or magnitude, limiting model performance compared to single-task counterparts. Sharpness-aware minimization (SAM) minimizes task loss while simultaneously reducing the sharpness of the loss landscape. Our empirical observations show that SAM effectively mitigates task conflicts in MTL. Motivated by these findings, we explore integrating SAM into MTL but face two key challenges. First, while both the average loss gradient and the individual task gradients (referred to as global and local information, respectively) contribute to SAM, how to combine them remains unclear. Second, directly computing each task gradient introduces significant computational and memory overhead. To address these challenges, we propose SAMO, a lightweight Sharpness-Aware Multi-task Optimization approach that leverages a joint global-local perturbation. The local perturbations are approximated using only forward passes and are layerwise normalized to improve efficiency. Extensive experiments on a suite of multi-task benchmarks demonstrate both the effectiveness and efficiency of our method. Code is available at https://github.com/OptMN-Lab/SAMO.
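To make the two ingredients in the abstract concrete, the following is a minimal sketch, not the SAMO algorithm itself. It shows (1) a task conflict, measured as a negative cosine similarity between two task gradients, and (2) a generic SAM step that perturbs the weights along the average (global) gradient before descending. The two quadratic toy losses and all function names are assumptions made for illustration. The sketch does not implement SAMO's local perturbations, its forward-pass approximation, or its layerwise normalization; the linked repository is the place to look for those.

```python
import math

def grad_task1(w):
    # Gradient of the hypothetical toy loss 0.5*(w0 - 1)^2 + 0.5*w1^2
    return [w[0] - 1.0, w[1]]

def grad_task2(w):
    # Gradient of the hypothetical toy loss 0.5*(w0 + 1)^2 + 2*w1^2
    return [w[0] + 1.0, 4.0 * w[1]]

def average_grad(w):
    # "Global" information: gradient of the average of the task losses
    g1, g2 = grad_task1(w), grad_task2(w)
    return [(a + b) / 2.0 for a, b in zip(g1, g2)]

def cosine(u, v):
    # Cosine similarity; a negative value means the gradients conflict in direction
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def sam_step(w, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    # Generic SAM update:
    # 1) move to the approximate worst-case point within an L2 ball of radius rho
    g = grad_fn(w)
    norm = math.sqrt(sum(x * x for x in g))
    w_adv = [wi + rho * gi / (norm + eps) for wi, gi in zip(w, g)]
    # 2) descend from the original weights using the gradient at that point
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

w = [0.0, 0.1]
conflict = cosine(grad_task1(w), grad_task2(w))  # negative at this point
w_next = sam_step(w, average_grad)
```

At the chosen point the two task gradients point in nearly opposite directions along the first coordinate, so their cosine similarity is negative, while the SAM step acts only on the averaged gradient. Using the global gradient alone ignores exactly the per-task (local) information that the abstract argues should also shape the perturbation.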
Similar Papers
Asynchronous Sharpness-Aware Minimization For Fast and Accurate Deep Learning
Machine Learning (CS)
Makes smart computer programs learn faster and better.
Modality-Aware SAM: Sharpness-Aware-Minimization Driven Gradient Modulation for Harmonized Multimodal Learning
CV and Pattern Recognition
Helps computers learn better from different kinds of information.
GCSAM: Gradient Centralized Sharpness Aware Minimization
Machine Learning (CS)
Makes computer learning better and faster.