news 2026/5/16 9:28:08

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率刷到99.5%以上的

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率刷到99.5%以上的

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率刷到99.5%以上的

第一次用CNN跑MNIST时,看着测试集上98%的准确率还挺满意。直到在Kaggle上看到有人用相同数据集跑出99.5%+的成绩,才发现自己连入门级数据集的潜力都没榨干。这就像以为掌握了加减乘除就能解微积分——深度学习的水远比想象中深。经过两个月的反复实验,终于让模型突破了99.5%大关,整个过程堪称一部"调参侠"的进化史。

1. 基础CNN的瓶颈与突破

初始的CNN架构简单得可怜:两个卷积层夹着ReLU和最大池化,最后接全连接层。这个模型在10个epoch后就稳定在98.1%左右,典型的"早熟"表现。通过TensorBoard可视化发现,验证集准确率在第5轮后就几乎走平,说明模型容量根本不够。

第一批改进方案:

# 关键改进点代码示例 class EnhancedCNN(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(1, 32, 5, padding=2), # 保持特征图尺寸 nn.BatchNorm2d(32), # 新增批归一化 nn.ReLU(inplace=True), nn.MaxPool2d(2)) self.block2 = nn.Sequential( nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), # 新增批归一化 nn.ReLU(inplace=True), nn.MaxPool2d(2)) self.classifier = nn.Sequential( nn.Flatten(), nn.Dropout(0.5), # 新增Dropout nn.Linear(64*7*7, 10))

调整后的模型出现了几个明显变化:

  • 通道数从[10,20]扩展到[32,64],增强特征提取能力
  • 添加BatchNorm层后,学习率可以提升3倍而不发散
  • 引入Dropout后训练集准确率下降,但验证集提升0.6%

注意:BatchNorm一定要放在卷积层和激活函数之间,这个顺序错误会导致效果大打折扣

验证集准确率变化:

改进措施准确率提升训练时间增幅
基础CNN98.1%-
+BatchNorm+0.9%+15%
+通道扩展+0.7%+25%
+Dropout+0.6%可忽略

2. 数据增强的艺术

当模型在原始数据上达到98.7%后,我开始在数据层面寻找突破点。MNIST的简单特性决定了不能使用太激进的数据增强,经过反复测试,最终确定了最佳组合:

transform = transforms.Compose([ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # 微小平移 transforms.RandomRotation((-5, 5)), # 小角度旋转 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])

数据增强效果对比:

  • 纯平移增强:+0.25%
  • 纯旋转增强:+0.18%
  • 组合增强:+0.42%
  • 添加弹性变形:-0.3%(过犹不及)

有趣的是,当增强幅度过大时(如旋转±15度),模型准确率反而下降。这是因为MNIST数字的形态特征比自然图像更敏感,过度变形会让"9"变得像"4"、"7"像"1"。

3. 学习率动态调整策略

固定学习率就像用固定速度爬山——平缓处太慢,陡峭处又容易翻车。尝试了三种动态调整方案:

  1. StepLR:每30个epoch乘以0.1
    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
  2. Cosine退火
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
  3. ReduceLROnPlateau
    scheduler = lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=0.5, patience=3)

实验结果:

  • StepLR在中期表现最好,但后期下降过快
  • Cosine退火整体平稳,但最高点不如ReduceLROnPlateau
  • ReduceLROnPlateau最终达到99.23%,是最佳选择

提示:监控验证集准确率而非训练损失作为调整依据,这样更可靠

4. 残差连接的降维打击

当传统CNN改进陷入瓶颈时,ResNet18带来了质的飞跃。但直接将ImageNet的架构用于MNIST会过犹不及,需要做针对性调整:

class MNISTResNet(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 16, 3, padding=1) # 输入通道改为1 self.bn1 = nn.BatchNorm2d(16) self.relu = nn.ReLU(inplace=True) # 简化版的残差块 self.layer1 = self._make_layer(16, 16, 2) self.layer2 = self._make_layer(16, 32, 2, stride=2) self.layer3 = self._make_layer(32, 64, 2, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1,1)) self.fc = nn.Linear(64, 10) def _make_layer(self, in_channels, out_channels, blocks, stride=1): downsample = None if stride != 1 or in_channels != out_channels: downsample = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, stride), nn.BatchNorm2d(out_channels)) layers = [] layers.append(ResidualBlock(in_channels, out_channels, stride, downsample)) for _ in range(1, blocks): layers.append(ResidualBlock(out_channels, out_channels)) return nn.Sequential(*layers)

关键改进点:

  • 将原始ResNet18的4层残差块减为3层
  • 初始卷积核从7x7改为3x3
  • 最终平均池化层输出尺寸设为1x1
  • 通道数缩减为[16,32,64]以适应小图像

性能对比:

模型类型参数量测试准确率训练时间(epoch)
增强版CNN1.2M99.23%45min
简化ResNet180.8M99.47%68min
标准ResNet1811.2M99.31%2.5h

5. 突破99.5%的终极组合

最终的突破来自多个微创新的叠加效应:

  1. 权重初始化:改用Kaiming初始化

    def init_weights(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') model.apply(init_weights)
  2. 优化器切换:从SGD改为RMSprop

    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)
  3. 标签平滑:缓解过拟合

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
  4. 测试时增强:对测试图像做5次随机变换取平均

    def predict(image): model.eval() outputs = [] for _ in range(5): aug_img = test_transform(image) # 包含随机变换 outputs.append(model(aug_img.unsqueeze(0))) return torch.mean(torch.stack(outputs), dim=0)

最终在测试集上的准确率曲线呈现出有趣的规律:每当引入一个新技巧,准确率就会上一个台阶,但提升幅度越来越小。从98%到99%相对容易,但从99%到99.5%需要付出十倍努力——这大概就是深度学习的边际效应吧。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/16 9:27:23

Arm Neoverse CMN-650错误处理机制详解

1. Arm Neoverse CMN-650错误处理机制概述在现代计算架构中,错误处理机制是确保系统可靠性的基石。Arm Neoverse CMN-650作为高性能互连网络,其错误处理寄存器组为系统级容错提供了硬件级支持。这套机制不仅能检测和纠正瞬时性错误,还能处理永…

作者头像 李华
网站建设 2026/5/16 9:26:21

Heightmapper终极教程:5步创建专业3D地形高度图的免费方案

Heightmapper终极教程:5步创建专业3D地形高度图的免费方案 【免费下载链接】heightmapper interactive heightmaps from terrain data 项目地址: https://gitcode.com/gh_mirrors/he/heightmapper 还在为3D地形建模而烦恼吗?Heightmapper是一款免…

作者头像 李华
网站建设 2026/5/16 9:23:07

构建现代化个人知识库:从信息孤岛到互联研究金库

1. 项目概述:从“信息孤岛”到“个人研究金库”如果你和我一样,常年混迹于学术圈、技术社区或者任何一个需要深度信息处理的领域,那么你一定对下面这个场景深恶痛绝:为了一个研究课题,你打开了十几个浏览器标签页&…

作者头像 李华
网站建设 2026/5/16 9:22:15

把旧路由器改造成远程ADB调试服务器:OpenWrt安装adb与公网访问指南

旧路由器变身远程ADB调试服务器:OpenWrt实战指南 在移动应用开发过程中,频繁连接USB数据线进行调试不仅效率低下,更限制了开发者的工作灵活性。想象一下,当你需要同时调试多台设备,或者在不同网络环境下快速切换测试场…

作者头像 李华