news 2026/4/17 18:05:25

Day49 - CBAM注意力机制

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day49 - CBAM注意力机制

1. 简介

CBAM (Convolutional Block Attention Module) 是一种轻量级的注意力模块,它可以无缝集成到任何CNN架构中,通过引入额外的开销来显著提升模型的性能。

与SE (Squeeze-and-Excitation) 模块主要关注通道注意力不同,CBAM 同时结合了通道注意力 (Channel Attention)空间注意力 (Spatial Attention)

这种串联的注意力机制使得网络能够依次学习"关注什么" (What to focus on) 和 "关注哪里" (Where to focus on)。

2. 核心原理

CBAM 包含两个子模块,通常采用串联方式连接(先通道后空间):

2.1 通道注意力模块 (Channel Attention Module, CAM)

通道注意力旨在探索通道之间的依赖关系。CBAM 的通道注意力改进了 SE 模块:

  • 不仅使用全局平均池化 (Average Pooling),还引入了全局最大池化 (Max Pooling)。
  • 认为最大池化能收集到更独特的对象特征,与平均池化互补。
  • 两个池化后的特征向量共享同一个多层感知机 (MLP) 网络。
  • 最终将两个输出相加并通过 Sigmoid 激活函数生成通道权重。

2.2 空间注意力模块 (Spatial Attention Module, SAM)

空间注意力旨在探索特征图在空间维度上的重要性(即哪些区域更重要)。

  • 在通道维度上进行平均池化和最大池化,得到两个 2D 特征图。
  • 将这两个特征图在通道维度拼接 (Concat)。
  • 通过一个 7x7 的卷积层进行特征融合。
  • 通过 Sigmoid 激活函数生成空间权重图。

3. 代码实现

以下是基于 PyTorch 的 CBAM 完整实现,包括通道注意力、空间注意力及其在 CNN 中的集成。

3.1 通道注意力 (ChannelAttention)

import torch import torch.nn as nn class ChannelAttention(nn.Module): def __init__(self, in_channels, ratio=16): super().__init__() # 平均池化和最大池化 self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) # 共享的全连接层 (MLP) # 使用1x1卷积代替全连接层,减少参数量并保持输入形状 self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // ratio, bias=False), nn.ReLU(), nn.Linear(in_channels // ratio, in_channels, bias=False) ) self.sigmoid = nn.Sigmoid() def forward(self, x): b, c, h, w = x.shape # 平均池化分支 avg_out = self.fc(self.avg_pool(x).view(b, c)) # 最大池化分支 max_out = self.fc(self.max_pool(x).view(b, c)) # 结果相加后经过Sigmoid attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1) # 权重作用于原特征图 return x * attention

3.2 空间注意力 (SpatialAttention)

class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super().__init__() # padding计算保证输出大小不变 padding = kernel_size // 2 # 输入通道为2 (AvgPool 1 + MaxPool 1) self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): # 在通道维度上求平均 (b, 1, h, w) avg_out = torch.mean(x, dim=1, keepdim=True) # 在通道维度上求最大 (b, 1, h, w) max_out, _ = torch.max(x, dim=1, keepdim=True) # 拼接 (b, 2, h, w) pool_out = torch.cat([avg_out, max_out], dim=1) # 卷积 + Sigmoid attention = self.conv(pool_out) return x * self.sigmoid(attention)

3.3 CBAM 模块组合

class CBAM(nn.Module): def __init__(self, in_channels, ratio=16, kernel_size=7): super().__init__() self.channel_attention = ChannelAttention(in_channels, ratio) self.spatial_attention = SpatialAttention(kernel_size) def forward(self, x): # 串联结构:先通道后空间 x = self.channel_attention(x) x = self.spatial_attention(x) return x

3.4 集成到 CNN 模型

在经典的卷积神经网络中,CBAM 模块通常被放置在卷积层和激活函数之后,或者池化层之前。以下是一个简单的 CBAM-CNN 示例:

class CBAM_CNN(nn.Module): def __init__(self): super().__init__() # 第一层卷积块 self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.cbam1 = CBAM(in_channels=32) # 集成CBAM # ... 后续层省略 ... # 假设这里还有更多层 # 全连接层 self.fc1 = nn.Linear(128 * 4 * 4, 512) self.dropout = nn.Dropout(p=0.5) self.fc2 = nn.Linear(512, 10) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.pool1(x) x = self.cbam1(x) # 应用注意力机制 # ... 后续前向传播 ... x = x.view(-1, 128 * 4 * 4) x = self.fc1(x) x = self.relu1(x) # 注意这里应该是对应的激活函数 x = self.dropout(x) x = self.fc2(x) return x

4. 训练与实验

在 CIFAR-10 数据集上的训练过程显示,引入 CBAM 后,模型能够更有效地聚焦于图像的关键特征。

  • 优化器: 使用 Adam 优化器,自适应调整学习率。
  • 学习率调度: 使用ReduceLROnPlateau,当验证集损失不再下降时自动降低学习率,有助于模型收敛到更优解。
  • 性能: 在约 50 个 Epoch 的训练中,模型能够达到较高的准确率 (如 86% 左右),证明了注意力机制对特征提取能力的增强作用。

5. 总结

CBAM 通过结合通道注意力和空间注意力,提供了一种即插即用的性能提升方案。

  • 轻量级: 参数量和计算量增加很少。
  • 通用性: 适用于各种 CNN 架构 (ResNet, MobileNet 等)。
  • 互补性: MaxPool 和 AvgPool 的结合保留了更丰富的特征信息。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 3:51:06

pywencai终极指南:轻松玩转同花顺问财数据获取

还在为获取股票数据而烦恼吗?pywencai这款工具将彻底改变你的数据获取方式!作为一款专为Python用户设计的同花顺问财数据获取工具,它能让你用几行代码就轻松获取专业的财经数据,无论是量化研究还是数据分析,都能事半功…

作者头像 李华
网站建设 2026/4/12 6:27:22

无损视频剪辑终极指南:如何用LosslessCut实现专业级编辑

无损视频剪辑终极指南:如何用LosslessCut实现专业级编辑 【免费下载链接】lossless-cut The swiss army knife of lossless video/audio editing 项目地址: https://gitcode.com/gh_mirrors/lo/lossless-cut 还在为视频剪辑后画质下降而烦恼?传统…

作者头像 李华
网站建设 2026/4/17 21:03:25

纯硬件解决方案:CD4511控制数码管的555时钟源设计

用555和CD4511搭建一个“不会死机”的数码管时钟你有没有遇到过这样的问题:单片机控制的数码管显示突然卡住、闪烁,或者程序跑飞后整个系统失灵?尤其是在高温、震动或强干扰环境下,软件方案的脆弱性暴露无遗。今天,我们…

作者头像 李华
网站建设 2026/4/13 6:14:52

Zotero PDF翻译插件集成豆包大模型:解决学术翻译痛点的一站式方案

还在为PDF文献翻译的专业术语不准确而烦恼?豆包大模型凭借其在中文语境下的深度理解能力,为Zotero用户提供精准的学术翻译体验。本文将采用"问题诊断→解决方案→效果验证"的三段式结构,带你深入理解集成豆包大模型的技术路径与实用…

作者头像 李华
网站建设 2026/4/16 14:12:39

Qwen3-4B-Base震撼发布:36万亿 tokens训练的40亿参数大模型

导语:Qwen3系列最新成员Qwen3-4B-Base正式发布,这款拥有40亿参数、经过36万亿tokens训练的基础大模型,凭借创新的三阶段训练架构和32k超长上下文能力,重新定义了中小规模语言模型的性能边界。 【免费下载链接】Qwen3-4B-Base 探索…

作者头像 李华