optimizer扩展实践:引入新型优化算法
在大模型训练的现实战场上,显存瓶颈常常成为压垮实验的最后一根稻草。你是否经历过这样的场景:满怀期待地启动 Llama-3-8B 的全参数微调任务,结果刚进入第一个 step 就遭遇 OOM(Out of Memory)?传统 AdamW 优化器对显存的“贪婪”需求,在千亿级参数面前几乎不可持续——仅优化器状态就可能吃掉数十 GB 显存,让单卡训练沦为奢望。
正是在这种背景下,ms-swift 框架走出了一条不一样的路。它没有选择简单堆算力,而是从优化器本身入手,系统性集成 GaLore、Q-Galore、LISA 等前沿轻量级优化技术,将原本沉重的训练负担“瘦身”到消费级显卡也能承载的程度。这不仅是工程上的巧思,更是一种范式转变:我们不再被动适应硬件限制,而是主动重构训练流程本身。
低秩投影的艺术:GaLore 如何重塑梯度更新
当模型参数动辄上亿时,每一层的梯度本质上是一个巨大的二维矩阵。比如一个 $5120 \times 4096$ 的 FFN 层权重,其梯度也是同样尺寸。AdamW 需要在每个参数上维护动量和方差,这意味着每步都要存储两个同形矩阵——显存消耗直接翻三倍。
GaLore 的突破在于,它意识到并非所有梯度方向都同等重要。通过奇异值分解(SVD),我们可以把梯度 $G$ 近似为:
$$
G \approx U S V^T, \quad U\in\mathbb{R}^{m\times r}, V\in\mathbb{R}^{n\times r}, r \ll \min(m,n)
$$
这里的 $r$ 通常设为 16 或 32,意味着我们将原始高维梯度压缩到一个极低维度的子空间中进行优化。真正聪明的是后续操作:只在这个低秩空间里跑 Adam 更新逻辑,然后再将更新量反投影回原空间。
这种设计带来了惊人的显存收益。以 LLaMA-7B 为例,全参数微调下 AdamW 所需优化器状态约为 60GB,而 GaLore 可将其压缩至不足 10GB。更重要的是,由于保留了前 $r$ 个主成分,关键更新方向得以保留,收敛稳定性远超随机稀疏化等粗暴手段。
实际使用中有个经验法则:对于 >7B 的模型,建议设置rank=16;若追求更高精度可尝试rank=32,但边际收益递减明显。另外,并非所有层都适合应用 GaLore——Embedding 和 LayerNorm 通常保持原样,避免引入不必要的噪声。
@register_optimizer('galore') class GaLoreAdam(Adam): def __init__(self, params, lr=1e-3, rank=16, update_proj_gap=50, alpha=1.0, **kwargs): self.rank = rank self.update_proj_gap = update_proj_gap # 控制U/V更新频率 self.alpha = alpha super().__init__(params, lr=lr, **kwargs) def step(self, closure=None): for group in self.param_groups: for p in group['params']: if p.grad is None or p.ndim < 2: # 跳过偏置等小张量 continue grad = p.grad.data state = self.state[p] if len(state) == 0: shape = grad.shape rows, cols = shape[0], shape[1:].numel() # 初始化左右投影矩阵 state['U'] = torch.randn(rows, self.rank, device=p.device) / (self.rank ** 0.5) state['V'] = torch.randn(cols, self.rank, device=p.device) / (self.rank ** 0.5) state['grad_shape'] = shape U, V = state['U'], state['V'] grad_flat = grad.view(U.size(0), -1) g_projected = U.t() @ grad_flat @ V # 投影到[r,r]空间 if 'exp_avg' not in state: state['exp_avg'] = torch.zeros_like(g_projected) state['exp_avg_sq'] = torch.zeros_like(g_projected) # 标准 Adam 更新(略) ... # 反投影并更新原始参数 update_flat = U @ state['exp_avg'] @ V.t() update = update_flat.view(p.shape) p.data.add_(update, alpha=-group['lr'])注意这个实现中的几个细节:
-update_proj_gap参数控制是否定期更新 $U/V$ 矩阵,防止投影空间僵化;
- 仅对二维及以上张量启用,避免干扰标量类参数;
- 使用.view()而非.reshape()保证内存连续性,提升性能。
双重压缩:Q-Galore 的量化飞跃
如果说 GaLore 是“减肥”,那 Q-Galore 就是“脱水+真空压缩”。它在低秩基础上叠加了 8-bit 量化,实现了真正的极致轻量化。
其核心思想非常直接:既然我们在低维空间做优化,那么连这个空间内的动量和方差也可以压缩。具体来说,采用仿射量化方案:
$$
x_{quant} = \text{round}\left(\frac{x - \min(x)}{\text{scale}}\right), \quad
x_{dequant} = x_{quant} \cdot \text{scale} + \min(x)
$$
其中 scale 通常是 $\max(x)-\min(x)/127$。关键在于,这种变换是可逆的,且误差可控。更重要的是,现代 GPU 对 int8 计算有专门加速单元,使得量化/反量化的开销几乎可以忽略。
实际效果令人震撼。在 Llama-3-8B 微调任务中,标准 AdamW 优化器状态约需 40GB 显存,GaLore 压缩至 ~6GB,而 Q-Galore 进一步降至2.5GB 左右——这意味着 RTX 3090(24GB)甚至 A10(24GB)都能胜任全参数微调!
def quantize_tensor(t, group_size=128): """按组进行 per-tensor 量化""" t_flatten = t.reshape(-1, group_size) scale = t_flatten.abs().max(dim=-1, keepdim=True)[0] / 127 t_quant = (t_flatten / scale).round().clamp(-128, 127).to(torch.int8) return t_quant, scale def dequantize_tensor(t_quant, scale, group_size=128): t_dequant = t_quant.float() * scale return t_dequant.reshape(-1) class QGaloreAdam(GaLoreAdam): def step(self, closure=None): for group in self.param_groups: for p in group['params']: if p.grad is None or p.ndim < 2: continue grad = p.grad.data state = self.state[p] U, V = state['U'], state['V'] # 先量化原始梯度 grad_quant, g_scale = quantize_tensor(grad) grad = dequantize_tensor(grad_quant, g_scale) # 投影到低秩空间 grad_flat = grad.view(U.size(0), -1) g_projected = U.t() @ grad_flat @ V # 再次量化用于优化器状态 proj_quant, p_scale = quantize_tensor(g_projected) g_projected = dequantize_tensor(proj_quant, p_scale) if 'exp_avg' not in state: state['exp_avg'] = torch.zeros_like(g_projected) state['exp_avg_sq'] = torch.zeros_like(g_projected) state['p_scale'] = p_scale # 缓存scale供反投影用 # Adam 更新逻辑(作用于 dequant 后的数据) # 反投影前必须反量化 update_projected = dequantize_tensor(state['exp_avg'].data, state['p_scale']) update_flat = U @ update_projected @ V.t() update = update_flat.view(p.shape) p.data.add_(update, alpha=-group['lr'])这里最易出错的地方是 scale 的管理。如果忘记缓存或误用了错误的 scale,会导致更新量失真,进而破坏收敛过程。建议每 100~500 步动态更新一次 scale,平衡稳定性和精度。
稀疏中的智慧:LISA 的动态感知策略
与前两者不同,LISA 不走降维路线,而是另辟蹊径——只更新重要的参数。它的灵感来源于观察:在大多数训练步骤中,只有少数参数具有显著梯度,其余近乎静止。
LISA 的工作方式如下:
1. 每隔若干步计算各参数的梯度绝对值作为“重要性评分”;
2. 按分数排序,选取 top-k%(如 50%)作为活跃集;
3. 仅对该子集维护独立的动量/方差;
4. 非活跃参数共享全局平均状态。
这种方法既减少了状态存储,又避免了完全冻结带来的表达能力损失。尤其适合长序列任务(如文档理解、视频分析),因为这些场景下 batch size 往往受限于显存,而 LISA 能有效缓解这一压力。
@register_optimizer('lisa_adam') class LISAAdam(Adam): def __init__(self, params, lr=1e-3, sparsity=0.5, update_freq=100, **kwargs): self.sparsity = sparsity # 活跃参数比例 self.update_freq = update_freq self.step_count = 0 super().__init__(params, lr=lr, **kwargs) def step(self, closure=None): self.step_count += 1 for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data state = self.state[p] # 定期更新重要性掩码 if 'mask' not in state or (self.step_count % self.update_freq == 0): importance = grad.abs() k = int((1 - self.sparsity) * importance.numel()) if k > 0 and k < importance.numel(): threshold = torch.kthvalue(importance.flatten(), k).values state['mask'] = (importance >= threshold) else: state['mask'] = torch.ones_like(importance, dtype=torch.bool) mask = state['mask'] sparse_grad = grad * mask if 'exp_avg' not in state: state['exp_avg'] = torch.zeros_like(sparse_grad) state['exp_avg_sq'] = torch.zeros_like(sparse_grad) # 仅更新 masked 区域 exp_avg = state['exp_avg'] exp_avg_sq = state['exp_avg_sq'] grad_masked = sparse_grad[mask] exp_avg[mask].mul_(group['betas'][0]).add_(grad_masked, alpha=1 - group['betas'][0]) sq = grad_masked.pow(2) exp_avg_sq[mask].mul_(group['betas'][1]).add_(sq, alpha=1 - group['betas'][1]) # 非活跃区域用均值填充(防止完全停滞) if mask.any() and (~mask).any(): avg_mom = exp_avg[mask].mean() avg_var = exp_avg_sq[mask].mean() exp_avg[~mask] = avg_mom exp_avg_sq[~mask] = avg_var denom = (exp_avg_sq.sqrt() + group['eps']) update = exp_avg / denom p.data.add_(update, alpha=-group['lr'])实践中发现,sparsity=0.5是个不错的起点,兼顾效率与性能。对于多模态任务(如 VQA),由于输入复杂度高,建议降低稀疏度至 0.7 以上,确保足够的更新灵活性。
插件化架构:如何让这一切无缝协作
ms-swift 的真正强大之处,在于它把这些复杂的优化逻辑封装成了即插即用的模块。整个系统基于注册机制构建:
optimizer: type: q_galore rank: 16 lr: 2e-5 betas: [0.9, 0.99] weight_decay: 0.01当你提交这样的配置时,框架会自动通过@register_optimizer查找对应类并实例化。整个过程对用户透明,无需修改任何训练脚本。
其底层架构清晰而稳健:
[用户配置 YAML] ↓ [Swift Trainer] ↓ [Optimizer Registry] ← 支持 galore/q_galore/lisa_adam ↓ [Distributed Training Engine (DDP/FSDP)] ↓ [Model Forward/Backward] ↓ [Custom Optimizer Step]更重要的是,这套设计遵循严格接口规范,确保兼容性。所有自定义优化器必须继承torch.optim.Optimizer,实现标准step()方法。同时内置故障回退机制:一旦检测到数值异常(如 NaN),可自动切换至 AdamW 继续训练,保障实验连续性。
实战建议与未来展望
结合大量实测经验,给出以下推荐组合:
| 场景 | 推荐配置 | 效果 |
|---|---|---|
| >7B 模型微调 | GaLore + LoRA | 显存节省 60%,精度无损 |
| 单卡消费级设备 | Q-Galore (rank=16) | A10 上跑通 Llama-3-8B |
| 长序列多模态任务 | LISA (sparsity=0.5) | 提升 batch size 2~3 倍 |
值得注意的是,这些方法并非互斥。例如,你可以对 QKV 投影层用 GaLore,FFN 层用 LISA,形成混合策略。未来随着 K-FAC 低秩近似、动态精度调度等新思路的融入,optimizer 扩展将变得更加智能。
最终,这种从“硬拼资源”到“软硬协同”的转变,正在重新定义大模型训练的可能性边界。它不仅降低了 AI 研发的准入门槛,也让绿色 AI、边缘训练等愿景变得触手可及。