news 2026/6/10 18:42:42

别再死记ResNet18结构图了!用PyTorch代码逐行拆解,搞懂残差连接到底怎么跑的

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记ResNet18结构图了!用PyTorch代码逐行拆解,搞懂残差连接到底怎么跑的

用PyTorch代码逐行解析ResNet18:残差连接的数据流动之谜

当你第一次看到ResNet18的结构图时,那些密密麻麻的箭头和方框是否让你感到困惑?实线与虚线有什么区别?1x1卷积到底在做什么?本文将带你用PyTorch代码一步步拆解这个经典网络,让你真正理解残差连接是如何工作的。

1. 残差网络的核心思想

传统的深度神经网络随着层数增加会出现梯度消失网络退化问题。ResNet的创新之处在于引入了残差学习的概念——不再让网络直接学习目标映射,而是学习目标映射与输入之间的残差。

想象一下教小孩投篮:与其让他直接从三分线投进篮筐(难度大),不如先让他站在篮下练习,然后逐步后退。残差学习就是这个原理——网络只需要学习当前输出与理想输出之间的"小差距"。

# 残差块的基本数学表达 output = F(x) + x # F(x)是残差函数,x是恒等映射

这种设计带来了两个关键优势:

  • 梯度可以直接通过恒等映射反向传播,缓解梯度消失
  • 网络可以更容易地学习微小调整,而不是完整的复杂变换

2. ResNet18的整体架构解析

让我们先看看PyTorch官方实现的ResNet18结构:

import torchvision.models as models resnet18 = models.resnet18() print(resnet18)

输出显示网络由以下几部分组成:

  1. 初始卷积层 (conv1)
  2. 批归一化层 (bn1)
  3. ReLU激活函数
  4. 最大池化层 (maxpool)
  5. 四个残差块阶段 (layer1-layer4)
  6. 全局平均池化 (avgpool)
  7. 全连接层 (fc)

关键点:四个残差块阶段分别包含[2, 2, 2, 2]个残差块,共8个残差块。由于每个残差块有2个卷积层,所以卷积层总数为1(初始conv) + 8×2 = 17层,加上最后的全连接层,正好18层。

3. 残差块的代码级解析

PyTorch实现中的基础残差块(BasicBlock)代码如下:

class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out

关键组件解析

组件作用参数说明
conv1第一个3x3卷积stride决定是否下采样
bn1批归一化加速训练,稳定梯度
conv2第二个3x3卷积固定stride=1
downsample下采样模块当维度不匹配时使用

4. 实线与虚线的秘密:维度匹配问题

结构图中的实线和虚线实际上代表了残差连接是否需要处理维度不匹配的情况:

  1. 实线连接:输入和输出维度完全相同,可以直接相加

    • 发生在每个阶段内部的残差块之间
    • 例如:layer1中的两个残差块之间
  2. 虚线连接:当跨阶段时,特征图尺寸减半,通道数翻倍

    • 需要下采样模块(1x1卷积)调整维度
    • 例如:layer1到layer2的过渡
# 下采样模块的典型实现 downsample = nn.Sequential( nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion) )

维度变化示例

  • 输入:64通道,112x112
  • 经过stride=2的conv1后:128通道,56x56
  • 恒等映射也需要从64->128通道,112->56尺寸

5. 数据流动的完整追踪

让我们跟踪一个224x224输入图像在ResNet18中的完整旅程:

  1. 初始卷积层

    x = self.conv1(x) # 7x7卷积,stride=2,输出通道64 x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) # 3x3池化,stride=2
    • 尺寸变化:224 -> 112 -> 56
    • 通道变化:3 -> 64
  2. layer1阶段

    • 两个BasicBlock,保持56x56尺寸
    • 实线连接,无需下采样
  3. layer2阶段

    • 第一个BasicBlock使用stride=2
    • 虚线连接,通过1x1卷积下采样
    • 尺寸:56 -> 28
    • 通道:64 -> 128
  4. 后续阶段

    • layer3:28 -> 14,128 -> 256
    • layer4:14 -> 7,256 -> 512
    • 最后通过全局平均池化得到512维向量

6. 常见问题与调试技巧

问题1:维度不匹配错误

  • 检查残差连接两端的张量形状
  • 确保downsample模块正确配置

问题2:训练不稳定

  • 确认批归一化层处于训练模式
  • 检查残差连接是否真的起作用(可以打印中间值)

调试技巧

# 打印各层输出形状的实用函数 def print_shapes(model, input_size=(1,3,224,224)): x = torch.randn(input_size) for name, layer in model.named_children(): x = layer(x) print(f"{name}: {x.shape}")

7. 残差网络的变体与实践建议

ResNet系列有多种变体,区别主要在于:

  • 残差块设计(BasicBlock/Bottleneck)
  • 网络深度(18/34/50/101/152)
  • 注意力机制引入(ResNeXt)

实践建议

  • 对于小数据集,ResNet18通常是足够的选择
  • 当需要更高精度时,可以考虑ResNet50
  • 修改残差块时,务必保持维度匹配原则
# 自定义残差块的示例 class CustomBlock(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() mid_channels = in_channels // 4 self.conv1 = nn.Conv2d(in_channels, mid_channels, 1) self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, padding=1) self.conv3 = nn.Conv2d(mid_channels, out_channels, 1) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU() def forward(self, x): identity = x out = self.relu(self.conv1(x)) out = self.relu(self.conv2(out)) out = self.bn(self.conv3(out)) out += identity return self.relu(out)

理解ResNet的最好方式就是亲手实现它。我在第一次复现时,最大的收获是认识到残差连接实际上创建了多条梯度传播路径,这使得深层网络能够有效训练。当你自己用PyTorch写出这些代码后,那些结构图中的箭头会突然变得清晰明了——它们不再是抽象的符号,而是真实的数据流动路径。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/6 22:07:27

实测5种Prompt模板对比100次查询性能

最近一篇关于RAG检索排序的论文给了我一些启发,我想在团队周会上分享我们对Prompt模板对DeepSeek品牌排名的实测结果。房地产行业的案例尤其典型:我们抽样了50家物业公司,覆盖住宅、商业、写字楼等不同物业类型,调研周期30天&…

作者头像 李华
网站建设 2026/6/6 22:04:56

全网首发:高通平台:gstreamer正确切换分辨率或帧率的代码

经过验证,确实可用。缺点:实际上等于重新打开一次摄像头,大约需要3秒。切换代码:WyCameraCtx* pCamera g_oStreamRtsp.pCamera;gst_element_unlink(pCamera->qtiqmmfsrc, g_oStreamRtsp.capsfilter);gst_element_set_state (p…

作者头像 李华