PyTorch Quantization-Aware Training: A Practical Guide to QAT and PTQ
1. Technical Analysis
1.1 Quantization Types
| Quantization Type | Description | Training Requirement | Accuracy Loss |
|---|---|---|---|
| PTQ (Post-Training Quantization) | Quantize after training | No retraining | Medium |
| QAT (Quantization-Aware Training) | Quantize during training | Requires retraining | Low |
| Dynamic Quantization | Weights only | No calibration data needed | Low |
| Static Quantization | Weights and activations | Requires calibration data | Low |
1.2 Bit-Width Comparison
| Bit Width | Memory Savings | Speedup | Accuracy Loss |
|---|---|---|---|
| INT8 | 4x | 2-4x | Low |
| INT4 | 8x | 4-8x | Medium |
| FP16 | 2x | 1.5x | Minimal |
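FP16 appears in the tables but not in the code sections below; as a minimal sketch, the simplest half-precision conversion looks like this (CUDA assumed for any speedup; for training, `torch.cuda.amp` mixed precision is the usual route instead):

```python
import torch
import torch.nn as nn

model = nn.Linear(100, 10)

if torch.cuda.is_available():
    # Cast weights to FP16: halves memory; speedup depends on GPU support.
    fp16_model = model.half().cuda()
    x = torch.randn(1, 100, dtype=torch.float16, device='cuda')
    print(fp16_model(x).dtype)  # torch.float16
```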
1.3 Quantization Workflow
The typical workflow is (a minimal end-to-end sketch follows the list):
1. Train the model in FP32
2. Prepare for quantization (insert observer/fake-quant nodes)
3. Calibrate (PTQ) or fine-tune (QAT)
4. Convert to INT8
5. Deploy for inference
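As a minimal end-to-end sketch, the five steps map onto PyTorch's eager-mode API roughly as follows; the toy model, shapes, and single forward pass are illustrative placeholders, and Sections 2.2 and 2.3 show the full PTQ and QAT flows. (Newer releases expose the same functions under `torch.ao.quantization`.)

```python
import torch
import torch.nn as nn

# Step 1: an already-trained FP32 model with quant/dequant stubs at the boundaries.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(8, 4)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = Toy().train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
prepared = torch.quantization.prepare_qat(model)   # step 2: insert fake-quant nodes
prepared(torch.randn(2, 8))                        # step 3: fine-tune (one pass here)
prepared.eval()
int8_model = torch.quantization.convert(prepared)  # step 4: convert to INT8
print(int8_model(torch.randn(2, 8)))               # step 5: INT8 inference on CPU
```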
2. Core Implementation
2.1 Dynamic Quantization
```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 200)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def dynamic_quantization():
    model = SimpleModel()
    model.eval()
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    x = torch.randn(1, 100)
    output = quantized_model(x)

    # Quantized Linear layers store packed weights that parameters() does not
    # expose, so compare serialized sizes instead of counting parameters.
    def model_size_kb(m):
        import io
        buf = io.BytesIO()
        torch.save(m.state_dict(), buf)
        return buf.getbuffer().nbytes / 1024

    print(f"FP32 model size: {model_size_kb(model):.2f} KB")
    print(f"INT8 model size: {model_size_kb(quantized_model):.2f} KB")

class DynamicQuantWrapper:
    def __init__(self, model):
        self.model = model
        self.quantized = None

    def quantize(self, qconfig=None):
        if qconfig is None:
            qconfig = torch.quantization.default_dynamic_qconfig
        self.quantized = torch.quantization.quantize_dynamic(
            self.model, qconfig_spec={nn.Linear: qconfig}
        )
        return self.quantized

    def inference(self, x):
        if self.quantized is None:
            return self.model(x)
        return self.quantized(x)
```
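A brief usage sketch of the wrapper above (random input as a stand-in for real data):

```python
wrapper = DynamicQuantWrapper(SimpleModel().eval())
wrapper.quantize()  # falls back to default_dynamic_qconfig
out = wrapper.inference(torch.randn(1, 100))
print(out.shape)    # torch.Size([1, 10])
```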
2.2 Static Quantization
```python
class StaticQuantModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(128 * 28 * 28, 10)  # expects 3x32x32 input
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def static_quantization(model, calibration_data):
    model.eval()  # PTQ requires eval mode before prepare
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model = torch.quantization.prepare(model, inplace=False)
    with torch.no_grad():
        for data in calibration_data:
            model(data)
    model = torch.quantization.convert(model, inplace=False)
    return model

class StaticQuantPipeline:
    def __init__(self, model):
        self.model = model
        self.calibration_data = None

    def set_calibration_data(self, data):
        self.calibration_data = data

    def run(self, backend='fbgemm'):
        self.model.eval()
        self.model.qconfig = torch.quantization.get_default_qconfig(backend)
        self.model = torch.quantization.prepare(self.model)
        self._calibrate()
        self.model = torch.quantization.convert(self.model)
        return self.model

    def _calibrate(self):
        if self.calibration_data is None:
            raise ValueError("Calibration data not set")
        self.model.eval()
        with torch.no_grad():
            for batch in self.calibration_data:
                self.model(batch)
```
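One step the eager-mode static flow usually adds before `prepare` is operator fusion, which merges adjacent Conv+ReLU (or Conv+BN+ReLU) pairs so each pair is quantized as a single unit, typically improving both accuracy and speed. A hedged sketch against the `StaticQuantModel` above:

```python
model = StaticQuantModel()
model.eval()

# Fuse Conv+ReLU pairs by submodule name before attaching a qconfig.
fused = torch.quantization.fuse_modules(
    model, [['conv1', 'relu1'], ['conv2', 'relu2']]
)

# `fused` then goes through the usual flow, e.g.:
# static_quantization(fused, calibration_data)
```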
2.3 Quantization-Aware Training (QAT)
```python
class QATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.layers = nn.Sequential(
            nn.Linear(100, 200),
            nn.ReLU(),
            nn.Linear(200, 200),
            nn.ReLU(),
            nn.Linear(200, 10)
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.layers(x)
        x = self.dequant(x)
        return x

def train_qat():
    model = QATModel()
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model = torch.quantization.prepare_qat(model, inplace=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(10):
        # Dummy data for illustration; replace with a real dataloader.
        inputs = torch.randn(32, 100)
        targets = torch.randint(0, 10, (32,))
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

    model.eval()
    model = torch.quantization.convert(model, inplace=False)
    return model

class QATTraining:
    def __init__(self, model, optimizer, loss_fn):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def prepare_qat(self, backend='fbgemm'):
        self.model.train()  # prepare_qat expects the model in training mode
        self.model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        self.model = torch.quantization.prepare_qat(self.model)

    def train_epoch(self, dataloader):
        self.model.train()
        for inputs, targets in dataloader:
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()

    def finalize(self):
        self.model.eval()
        self.model = torch.quantization.convert(self.model)
        return self.model
```
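A refinement from the official QAT tutorial, hedged here since module paths vary across PyTorch versions: after the first few epochs, freeze the observers so the learned scale/zero-point stop drifting, and freeze batch-norm statistics if the model has BN layers:

```python
# Assumes `model` has been through prepare_qat and trained for a few epochs.
model.apply(torch.quantization.disable_observer)      # freeze quantization ranges

# Only relevant for models containing BatchNorm layers:
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)   # freeze BN running stats
```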
2.4 Exporting the Quantized Model
```python
class QuantizationExporter:
    @staticmethod
    def export_onnx(model, input_shape, filepath):
        model.eval()
        dummy_input = torch.randn(*input_shape)
        torch.onnx.export(
            model,
            dummy_input,
            filepath,
            opset_version=13,  # the keyword is opset_version, not opset_import
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output']
        )

    @staticmethod
    def export_torchscript(model, input_shape, filepath):
        model.eval()
        dummy_input = torch.randn(*input_shape)
        traced_model = torch.jit.trace(model, dummy_input)
        traced_model.save(filepath)

    @staticmethod
    def optimize_for_mobile(model):
        # Script the (already quantized) model, then apply mobile optimizations.
        from torch.utils.mobile_optimizer import optimize_for_mobile
        model.eval()
        scripted = torch.jit.script(model)
        return optimize_for_mobile(scripted)
```
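A hedged usage sketch (the filename is arbitrary). Note that ONNX export supports only a subset of PyTorch's quantized operators, so TorchScript is often the safer route for quantized models:

```python
# Quantize with the QAT flow from Section 2.3, then export.
int8_model = train_qat()
QuantizationExporter.export_torchscript(int8_model, (1, 100), 'qat_model.pt')

# Reload and run the scripted model.
loaded = torch.jit.load('qat_model.pt')
print(loaded(torch.randn(1, 100)).shape)
```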
3. Performance Comparison
3.1 Quantization Method Comparison
| Method | Accuracy Loss | Speedup | Memory Savings | Implementation Complexity |
|---|---|---|---|---|
| Dynamic quantization | Low | 2x | 4x | Low |
| Static quantization | Low | 3x | 4x | Medium |
| QAT | Minimal | 3x | 4x | High |
| FP16 | Minimal | 1.5x | 2x | Low |
3.2 Inference Latency Comparison
| Model | FP32 Latency | INT8 Latency | Speedup |
|---|---|---|---|
| ResNet-18 | 10 ms | 3 ms | 3.3x |
| MobileNet | 8 ms | 2.5 ms | 3.2x |
| BERT-base | 50 ms | 15 ms | 3.3x |
| GPT-2 | 200 ms | 60 ms | 3.3x |
3.3 Accuracy Comparison
| Model | FP32 | PTQ | QAT | Accuracy Drop |
|---|---|---|---|---|
| ResNet-18 | 71.8% | 71.5% | 71.7% | PTQ: -0.3% |
| MobileNet | 70.0% | 69.5% | 69.9% | PTQ: -0.5% |
| BERT-base | 82.5% | 81.8% | 82.3% | PTQ: -0.7% |
4. Best Practices
4.1 Validating the Quantized Model
```python
class QuantizationValidator:
    def __init__(self, fp32_model, quantized_model):
        self.fp32_model = fp32_model
        self.quantized_model = quantized_model

    def validate(self, test_data):
        """Return the fraction of samples where the two models agree."""
        self.fp32_model.eval()
        self.quantized_model.eval()
        fp32_predictions = []
        quantized_predictions = []
        with torch.no_grad():
            for data in test_data:
                fp32_predictions.append(self.fp32_model(data).argmax(dim=1))
                quantized_predictions.append(self.quantized_model(data).argmax(dim=1))
        # Compare per sample; the original per-batch torch.equal version counted
        # a batch as a match only if every prediction in it agreed.
        fp32_all = torch.cat(fp32_predictions)
        quant_all = torch.cat(quantized_predictions)
        return (fp32_all == quant_all).float().mean().item()
```
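A hedged usage sketch pairing the validator with the dynamic-quantization example from Section 2.1 (random batches stand in for a real test set):

```python
fp32 = SimpleModel().eval()
int8 = torch.quantization.quantize_dynamic(fp32, {nn.Linear}, dtype=torch.qint8)

test_batches = [torch.randn(16, 100) for _ in range(4)]
agreement = QuantizationValidator(fp32, int8).validate(test_batches)
print(f"FP32/INT8 prediction agreement: {agreement:.2%}")
```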
4.2 Choosing a Quantization Config
```python
def select_qconfig(backend='fbgemm'):
    if backend == 'fbgemm':
        return torch.quantization.get_default_qconfig('fbgemm')
    elif backend == 'qnnpack':
        return torch.quantization.get_default_qconfig('qnnpack')
    else:
        raise ValueError(f"Unknown backend: {backend}")

class QuantizationConfig:
    @staticmethod
    def for_device(device_type):
        # fbgemm targets x86 servers; qnnpack targets ARM / mobile.
        if device_type == 'cpu':
            return torch.quantization.get_default_qconfig('fbgemm')
        elif device_type == 'mobile':
            return torch.quantization.get_default_qconfig('qnnpack')
        else:
            return torch.quantization.get_default_qconfig('fbgemm')
```
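Beyond the backend defaults, a custom `QConfig` can be assembled from observer classes. The combination below (histogram observer for activations, symmetric per-channel min/max observer for weights) is a hedged sketch, not a recommendation for any particular model:

```python
custom_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.HistogramObserver.with_args(
        dtype=torch.quint8
    ),
    weight=torch.quantization.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)

# Assigned like any other qconfig before prepare()/prepare_qat():
# model.qconfig = custom_qconfig
```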
5. Summary
Quantization is a key model-optimization technique:
- PTQ: fast deployment, no retraining required
- QAT: higher accuracy, requires retraining (fine-tuning)
- Dynamic quantization: weights only; well suited to Transformers
- Static quantization: weights and activations; well suited to CNNs
Key numbers from the comparisons above:
- INT8 quantization yields roughly a 3-4x speedup
- Memory footprint drops by about 75%
- QAT accuracy loss is under 0.3%
- PTQ accuracy loss is around 0.3-0.7%