news 2026/5/11 15:45:57

Day 46 通道注意力机制

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 46 通道注意力机制

在深度学习中,注意力机制(Attention Mechanism) 是让模型学会“关注重点”的方法。正如人类在看图时会自动聚焦于主体(如猫、车、人脸),而忽略背景,模型也希望学会同样的能力。

常见的通道注意力

其中,以SE 注意力最为经典

import torch import torch.nn as nn # Squeeze-and-Excitation 模块实现 class SEBlock(nn.Module): def __init__(self, in_channels, reduction=16): super(SEBlock, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) # Squeeze self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(in_channels // reduction, in_channels, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)

通道注意力在CNN中的应用:

class SimpleCNN(nn.Module): def __init__(self, num_classes=10): super(SimpleCNN, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), SEBlock(64) # 加入通道注意力模块 ) self.layer2 = nn.Sequential( nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), SEBlock(128) ) self.fc = nn.Linear(128, num_classes) def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = torch.mean(out, dim=[2, 3]) # Global AvgPool out = self.fc(out) return out model = SimpleCNN() print(model)
SE 注意力后的特征图和热力图
if __name__ == "__main__": # 1. 加载图片 img_tensor, img = load_image() # 2. 初始化两个CNN(加SE和不加SE) cnn_no_se = SimpleCNN() cnn_with_se = CNNWithSE() cnn_no_se.eval() cnn_with_se.eval() # 3. 提取特征 feat1_no_se, feat2_no_se = cnn_no_se(img_tensor) feat1_with_se, feat2_with_se = cnn_with_se(img_tensor) # 4. 可视化对比(第一层特征图) visualize_feat(feat1_no_se, "不加SE的第一层特征图") visualize_feat(feat1_with_se, "加SE的第一层特征图")
# 可视化空间注意力热力图(显示模型关注的图像区域) def visualize_attention_map(model, test_loader, device, class_names, num_samples=3): """可视化模型的注意力热力图,展示模型关注的图像区域""" model.eval() # 设置为评估模式 with torch.no_grad(): for i, (images, labels) in enumerate(test_loader): if i >= num_samples: # 只可视化前几个样本 break images, labels = images.to(device), labels.to(device) # 创建一个钩子,捕获中间特征图 activation_maps = [] def hook(module, input, output): activation_maps.append(output.cpu()) # 为最后一个卷积层注册钩子(获取特征图) hook_handle = model.conv3.register_forward_hook(hook) # 前向传播,触发钩子 outputs = model(images) # 移除钩子 hook_handle.remove() # 获取预测结果 _, predicted = torch.max(outputs, 1) # 获取原始图像 img = images[0].cpu().permute(1, 2, 0).numpy() # 反标准化处理 img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3) img = np.clip(img, 0, 1) # 获取激活图(最后一个卷积层的输出) feature_map = activation_maps[0][0].cpu() # 取第一个样本 # 计算通道注意力权重(使用SE模块的全局平均池化) channel_weights = torch.mean(feature_map, dim=(1, 2)) # [C] # 按权重对通道排序 sorted_indices = torch.argsort(channel_weights, descending=True) # 创建子图 fig, axes = plt.subplots(1, 4, figsize=(16, 4)) # 显示原始图像 axes[0].imshow(img) axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}') axes[0].axis('off') # 显示前3个最活跃通道的热力图 for j in range(3): channel_idx = sorted_indices[j] # 获取对应通道的特征图 channel_map = feature_map[channel_idx].numpy() # 归一化到[0,1] channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8) # 调整热力图大小以匹配原始图像 from scipy.ndimage import zoom heatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2])) # 显示热力图 axes[j+1].imshow(img) axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet') axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}') axes[j+1].axis('off') plt.tight_layout() plt.show() # 调用可视化函数 visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/11 6:22:49

LVGL移植驱动开发:手把手教程(基于STM32)

从零开始移植LVGL到STM32:一个嵌入式工程师的实战手记最近接手了一个工业HMI项目,客户要求在一块3.5寸TFT屏上实现流畅的图形界面。没有选择TouchGFX——不是它不好,而是成本和授权问题让小团队望而却步。最终我们选了LVGL,开源、…

作者头像 李华
网站建设 2026/4/28 9:39:37

Keil4安装超详细版:驱动与注册机处理全解析

Keil4 安装实战指南:从驱动配置到授权激活的完整解决方案 在嵌入式开发的世界里, Keil Vision4 (简称 Keil4)虽然不是最新版本,但至今仍是许多工程师手中的“主力工具”。尤其是在维护老旧项目、适配经典 STM32 芯片…

作者头像 李华
网站建设 2026/4/21 10:59:21

通义千问3-14B模型压缩:知识蒸馏的应用案例

通义千问3-14B模型压缩:知识蒸馏的应用案例 1. 引言:大模型轻量化的现实需求 随着大语言模型在推理能力、上下文长度和多语言支持等方面的持续突破,其参数规模也迅速攀升。然而,高性能往往伴随着高昂的部署成本。以百亿级参数模…

作者头像 李华
网站建设 2026/5/1 4:12:24

VibeThinker-1.5B部署全流程:从镜像拉取到网页调用

VibeThinker-1.5B部署全流程:从镜像拉取到网页调用 1. 引言 随着大模型技术的快速发展,小型参数模型在特定任务上的高效推理能力逐渐受到关注。VibeThinker-1.5B 是微博开源的一款小参数语言模型,拥有15亿参数,专为数学推理与编…

作者头像 李华
网站建设 2026/5/9 1:23:58

Hunyuan HY-MT1.5-1.8B部署教程:3步完成vLLM服务启动

Hunyuan HY-MT1.5-1.8B部署教程:3步完成vLLM服务启动 1. 模型介绍与技术背景 1.1 HY-MT1.5-1.8B 模型概述 混元翻译模型 1.5 版本(Hunyuan MT 1.5)包含两个核心模型:HY-MT1.5-1.8B 和 HY-MT1.5-7B,分别拥有 18 亿和…

作者头像 李华
网站建设 2026/5/9 12:12:15

PyTorch-2.x镜像使用指南:ipykernel配置多环境教程

PyTorch-2.x镜像使用指南:ipykernel配置多环境教程 1. 环境介绍与核心特性 本镜像为 PyTorch-2.x-Universal-Dev-v1.0,基于官方最新稳定版 PyTorch 构建,专为深度学习开发场景优化。系统经过精简处理,移除冗余缓存和无用依赖&am…

作者头像 李华