news 2026/4/18 3:44:26

使用torch.compile与梯度累积加速模型训练

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
使用torch.compile与梯度累积加速模型训练

训练一个具有深度Transformer架构的语言模型是耗时的。然而,有些技巧可以用来加速训练。在本文中,你将学习到:

  • 使用 torch.compile() 加速模型
  • 使用梯度累积来训练具有更大有效批次大小的模型

让我们开始吧!

概述

本文分为两个部分:

  • 使用 torch.compile()
  • 梯度累积

使用 torch.compile

当你在PyTorch中编写并运行模型代码时,它是在eager模式下执行的。这意味着代码是一行一行执行的,结果存储在内存中。这是Python的原生方式,因为它是一种解释型语言。你知道这一点是因为当代码出现错误时,只有运行到该行时才会看到错误提示。

在eager模式下运行模型速度较慢。从PyTorch 2.0开始,你可以使用torch.compile()来编译模型以提高性能。这会生成一个经过优化的新模型对象。它不是你用nn.Module创建的原始模型对象,但它与原始模型共享相同的张量。你可以像往常一样使用这个编译后的模型进行前向传播、反向传播和优化器更新。

将模型构建并编译成计算图正是TensorFlow 1.0的设计思路。这使得调试更加困难,因为你执行的模型无法与你编写的代码逐行对应。因此,在运行试验并确认模型没有错误之前,你不应该编译模型。

并非所有模型都可以编译。但是,如果你的模型支持编译,你将立即受益于速度提升。要编译一个模型,你只需要在准备使用模型之前替换模型对象:

... model = LlamaForPretraining(model_config).to(device) model.load_state_dict(checkpoint) model = torch.compile(model) ...

不要在编译后加载模型权重。这是因为编译后的模型是一个与原始模型共享权重的对象。在编译过程中,构建的计算图引用了原始模型的权重张量。如果你在编译后加载权重,模型可能无法按预期工作。

同样,要保存编译后的模型,你应该引用原始模型的状态字典,如下所示:

torch.save(getattr(model, "_orig_mod", model).state_dict(), "model.pth")

可以通过model._orig_mod访问编译模型中的原始模型。在上面的代码中,我们使用getattr(model, "_orig_mod", model)来获取原始模型(如果存在),或者如果不存在则使用模型本身。这行代码对编译模型和原始模型都适用。

梯度累积

当你训练一个模型时,你在反向传播上花费的时间可能是前向传播的两到三倍。这是因为反向传播计算强度更大,并且占用更多内存。

一个简单的加速训练技巧是减少反向传播的次数。这可以通过增加批次大小来实现:对于相同数量的数据样本,更大的批次大小意味着要处理的批次更少。

然而,更大的批次大小需要更多内存。在内存受限的环境中,你可以通过运行多次前向传播并累积梯度来模拟更大的批次大小。这被称为梯度累积

用代码来解释这个想法更容易:

.. accumulate_steps = 4 for epoch in range(num_epochs): optimizer.zero_grad() for i, batch in enumerate(dataloader): # 获取批次数据 input_ids, target_ids = batch # 创建注意力掩码:因果掩码 + 填充掩码 attn_mask = create_causal_mask(input_ids.shape[1], device) + \ create_padding_mask(input_ids, PAD_TOKEN_ID, device) # 从模型提取输出 logits = model(input_ids, attn_mask) # 计算损失:logits与目标之间的交叉熵,忽略填充标记 loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1)) loss = loss / accumulate_steps # 运行反向传播,但每`accumulate_steps`步才更新一次 loss.backward() if (i + 1) % accumulate_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() scheduler.step()

上面的训练循环摘自上一篇关于在本地GPU上训练Llama模型的文章。

通常,当你运行一次前向传播时,你会计算损失。然后调用loss.backward()通过模型参数反向传播损失梯度。在PyTorch中,backward()方法是累积的,这意味着梯度是相加的。因此,你需要在运行反向传播之前显式调用optimizer.zero_grad()来清除梯度。

在上面的代码中,你故意不在每次迭代中都调用optimizer.zero_grad()。相反,你对损失(除以accumulate_steps)运行反向传播。这样,梯度被缩小但在accumulate_steps次迭代中累积。每经过accumulate_steps次迭代,你才运行优化器来调整模型参数。

这种方法产生的结果与使用更大批次大小获得的结果相当。然而,由于你运行的优化器更新次数更少,学习率调度器应相应调整。这意味着你需要用不同的步数来初始化调度器:

... num_training_steps = (len(dataloader) // accumulate_steps) * num_epochs cosine_scheduler = lr_scheduler.CosineAnnealingLR( optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0 )

进一步阅读

以下是一些你可能感兴趣的资料:

  • torch.compile 文档
  • PyTorch 文档中的自动混合精度示例

总结

在本文中,你了解到使用torch.compile()可以通过编译计算图来帮助你加速模型。你还了解到,梯度累积是一种通过累积多个小批次的梯度来训练更大有效批次大小的技术。由于这种方式减少了优化器更新次数,你可以节省反向传播和参数更新的时间。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/10 22:15:51

计算机毕业设计springboot房车旅途 基于SpringBoot的房车租赁与售卖一体化平台 SpringBoot+Vue智慧房车出行服务系统

计算机毕业设计springboot房车旅途(配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。疫情之后,自驾露营热度飙升,把“家”装在车轮上成为年轻家庭的新宠。传…

作者头像 李华
网站建设 2026/4/11 21:15:02

2005-2024年上市公司管理者短视主义数据+stata代码

数据名称:2005-2024年管理者短视主义数据 时间:2005-2024年 数据量:53104条 范围:沪深A股上市公司 包含剔除金融stpt、未剔除版本 包含原始数据、处理代码(stata)、最终结果 指标构建:基于…

作者头像 李华
网站建设 2026/4/8 11:34:38

文字游戏:进化之路2.0二开完美版本源码 带后台

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 文字游戏:进化之路2.0二开完美版本源码 带后台 基于原版二开。原版没有后台功能,前端某些功能也是没有的! 后端部分功能参考额曜崽i的版本思路&am…

作者头像 李华
网站建设 2026/4/8 21:39:42

Node.js——Node.js 中间件与控制器实现问题

问题难点 在实现复杂的业务逻辑时,如何正确使用中间件处理请求、如何设计高效的控制器成为关键问题。 解决方案 Egg.js提供了灵活的中间件机制和基于装饰器的控制器实现方式。 Demo代码: // app/middleware/auth.ts - 认证中间件 import { Context, Next…

作者头像 李华
网站建设 2026/4/5 23:32:06

uni-app—— 小程序表单页面键盘弹起布局错乱问题

问题现象 表单页面点击输入框,键盘弹起后:平台表现安卓输入框位置错位,光标飘到其他位置iOS键盘遮挡输入框,看不到输入内容问题原因 当页面同时存在以下三个因素时,容易出现布局错乱: scroll-view float布…

作者头像 李华
网站建设 2026/4/4 23:11:35

什么是Java可重入锁?

大家好,我是锋哥。今天分享关于【什么是Java可重入锁?】面试题。希望对大家有帮助; 什么是Java可重入锁? Java 可重入锁(Reentrant Lock)是 Java 中的一种高级同步工具,用于控制对共享资源的访…

作者头像 李华