上周调一个YOLOv5的工业检测模型,输入分辨率拉到1280x1280,batch_size刚调到8,12G的显存直接爆了。监控显存曲线发现,前向计算时显存占用突然飙升——典型的激活值显存瓶颈。这时候要么砍分辨率,要么减batch,但检测任务对小目标敏感,分辨率不能降;batch太小又影响BN统计。这种两难境地,搞过大规模模型训练的兄弟应该都遇到过。
这时候就该混合精度训练上场了。不是换硬件,而是让计算更聪明。
混合精度到底混了什么?
简单说,就是让模型在训练时,权重、梯度用FP32(单精度)存,但前向计算用FP16(半精度)跑。FP16的显存占用是FP32的一半,计算速度还能利用Tensor Core加速。但直接全用FP16训练会出问题:数值范围太小(FP16最大表示65504,FP32是3.4e38),梯度容易下溢变成0,特别是YOLO这种有大量卷积和归一化操作的结构。
所以需要“混合”:关键地方用FP32保平安,其他时候用FP16冲速度。
PyTorch的Amp怎么用
PyTorch从1.6开始把Amp集成到torch.cuda.amp里,用起来比老版的NVIDIA Apex清爽很多。看个最小化的例子:
importtorchfromtorch.