news 2026/4/18 9:38:07

YOLOv8 改进 - 注意力机制 | EMA(Efficient Multi-Scale Attention)高效多尺度注意力通过跨空间学习增强特征表征

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
YOLOv8 改进 - 注意力机制 | EMA(Efficient Multi-Scale Attention)高效多尺度注意力通过跨空间学习增强特征表征

前言

本文提出了新颖高效的多尺度注意力(EMA)模块,并将其与YOLOv8结合以提升性能。该模块着重保留各通道信息、降低计算开销,通过将部分通道重塑为批量维度、分组通道维度,使空间语义特征分布更均匀。其创新点包括高效多尺度注意力机制、通道维度重塑、并行子网络设计等,在目标检测任务中表现出色。EMA模块结合通道和空间信息,采用多尺度并行子网络结构,优化坐标注意力机制。实验表明,将EMA集成进YOLOv8后,模型在图像分类和目标检测任务上有更好表现。

文章目录: YOLOv8改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv8改进专栏

文章目录

  • 前言
  • 介绍
    • 摘要
    • 创新点
  • 文章链接
  • 基本原理
  • 核心代码
  • 引入代码
  • tasks注册
    • 步骤1:
    • 步骤2
  • 配置yolov8-EMA.yaml
  • 实验
    • 脚本
    • 结果

介绍

摘要

通道或空间注意力机制在许多计算机视觉任务中表现出显著的效果,可以生成更清晰的特征表示。然而,通过通道维度缩减来建模跨通道关系可能会对提取深度视觉表示带来副作用。本文提出了一种新颖高效的多尺度注意力(EMA)模块。该模块着重于保留每个通道的信息并减少计算开销,我们将部分通道重新调整为批次维度,并将通道维度分组为多个子特征,使空间语义特征在每个特征组内分布均匀。具体来说,除了在每个并行分支中对全局信息进行编码以重新校准通道权重外,这两个并行分支的输出特征还通过跨维度交互进一步聚合,以捕捉像素级的成对关系。我们在图像分类和目标检测任务上进行了广泛的消融研究和实验,使用流行的基准数据集(如CIFAR-100、ImageNet-1k、MS COCO和VisDrone2019)来评估其性能。

创新点

  1. 高效的多尺度注意力机制:EMA模块提出了一种高效的多尺度注意力机制,能够同时捕获通道和空间信息,并在不增加太多参数和计算成本的情况下有效地提高特征表示能力。

  2. 通道维度重塑:EMA模块通过将部分通道重塑为批量维度,将通道维度分组为多个子特征,从而使空间语义特征在每个特征组内得到良好分布,提高了特征的表达能力。

  3. 并行子网络设计:EMA模块采用了并行子网络设计,有助于捕获跨维度的交互作用和建立维度间的依赖关系,提高了模型对长距离依赖关系的建模能力。

  4. 性能优越:EMA模块在目标检测任务中表现出色,相较于传统的注意力模块(如CA和CBAM),EMA在保持模型尺寸和计算效率的同时,取得了更好的性能表现,证明了其在提升模型性能方面的有效性和高效性。

  5. 适用性广泛:EMA模块的模型尺寸适中,适合在移动终端上部署,并且在各种计算机视觉任务中都表现出色,具有广泛的应用前景和实际意义。

文章链接

论文地址:论文地址

代码地址:代码地址

基本原理

EMA(Efficient Multi-Scale Attention)模块是一种新颖的高效多尺度注意力机制,旨在提高计算机视觉任务中的特征表示效果。 EMA注意力模块通过结合通道和空间信息、采用多尺度并行子网络结构以及优化坐标注意力机制,实现了更加高效和有效的特征表示,为计算机视觉任务的性能提升提供了重要的技术支持。

  1. 通道和空间注意力的结合:EMA模块通过将通道和空间信息相结合,实现了通道维度的信息保留和降低计算负担。这种结合有助于在特征表示中捕捉跨通道关系,同时避免了通道维度的削减,从而提高了模型的表现效果。

  2. 多尺度并行子网络:EMA模块采用多尺度并行子网络结构,其中包括一个处理1x1卷积核和一个处理3x3卷积核的并行子网络。这种结构有助于有效捕获跨维度交互作用,建立不同维度之间的依赖关系,从而提高特征表示的能力。

  3. 坐标注意力(CA)的再审视:EMA模块在坐标注意力(CA)的基础上进行了改进和优化。CA模块通过将位置信息嵌入通道注意力图中,实现了跨通道和空间信息的融合。EMA模块在此基础上进一步发展,通过并行子网络块有效捕获跨维度交互作用,建立不同维度之间的依赖关系。

  4. 特征聚合和交互:EMA模块通过并行子网络的设计,有助于实现特征的聚合和交互,从而提高模型对长距离依赖关系的建模能力。这种设计避免了更多的顺序处理和大规模深度,使模型更加高效和有效。

下图是结构,其中包括输入、特征重组、通道注意力和输出步骤。

  1. 输入阶段

    • 输入张量形状为( c , h , w ) (c, h, w)(c,h,w),其中 (c) 为通道数,(h) 为高度,(w) 为宽度。
  2. 特征重组阶段

    • 输入张量被划分为( g (g(g) 个组,每组包含( c g ) (\frac{c}{g})(gc)个通道。重组后的张量形状为( g × batch size ) (g \times \text{batch size})(g×batch size)
  3. 特征提取和聚合阶段

    • 特征通过( ( 1 × 1 ) ) ((1 \times 1))((1×1))的卷积层和平均池化层提取并聚合。
    • 聚合后的特征通过一个 Sigmoid 激活函数进行处理,以生成通道权重。
  4. 通道注意力阶段

    • 特征被进一步通过( ( 3 × 3 ) ) ((3 \times 3))((3×3))卷积层和平均池化层提取,得到每个通道的注意力权重。
    • 使用 Softmax 进行归一化处理,确保权重分布合理。
  5. 特征融合阶段

    • 提取到的特征通过矩阵乘法与原始特征进行融合,生成最终的通道注意力特征。
    • 最终特征通过 Sigmoid 激活函数处理,以生成像素级别的注意力图。
  6. 输出阶段

    • 经过以上处理后,生成的特征被重组为原始输入的形状( ( c , h , w ) ) ((c, h, w))((c,h,w)),形成最终输出。

通过这些步骤,实现了对输入特征的多尺度注意力处理,增强了特征表示的辨识能力,同时降低了计算开销。

核心代码

importtorchfromtorchimportnnclassEMA(nn.Module):def__init__(self,channels,c2=None,factor=32):super(EMA,self).__init__()self.groups=factor# 分组数,默认为32assertchannels//self.groups>0# 确保通道数能够被分组数整除self.softmax=nn.Softmax(-1)# 定义 Softmax 层,用于最后一维度的归一化self.agp=nn.AdaptiveAvgPool2d((1,1))# 自适应平均池化,将特征图缩小为1x1self.pool_h=nn.AdaptiveAvgPool2d((None,1))# 自适应平均池化,保留高度维度,将宽度压缩为1self.pool_w=nn.AdaptiveAvgPool2d((1,None))# 自适应平均池化,保留宽度维度,将高度压缩为1self.gn=nn.GroupNorm(channels//self.groups,channels//self.groups)# 分组归一化self.conv1x1=nn.Conv2d(channels//self.groups,channels//self.groups,kernel_size=1,stride=1,padding=0)# 1x1卷积self.conv3x3=nn.Conv2d(channels//self.groups,channels//self.groups,kernel_size=3,stride=1,padding=1)# 3x3卷积defforward(self,x):b,c,h,w=x.size()# 获取输入张量的尺寸:批次、通道、高度、宽度group_x=x.reshape(b*self.groups,-1,h,w)# 将张量按组重构:批次*组数, 通道/组数, 高度, 宽度x_h=self.pool_h(group_x)# 对高度方向进行池化,结果形状为 (b*groups, c//groups, h, 1)x_w=self.pool_w(group_x).permute(0,1,3,2)# 对宽度方向进行池化,并转置结果形状为 (b*groups, c//groups, 1, w)hw=self.conv1x1(torch.cat([x_h,x_w],dim=2))# 将池化后的特征在高度方向拼接后进行1x1卷积x_h,x_w=torch.split(hw,[h,w],dim=2)# 将卷积后的特征分为高度特征和宽度特征x1=self.gn(group_x*x_h.sigmoid()*x_w.permute(0,1,3,2).sigmoid())# 结合高度和宽度特征,应用分组归一化x2=self.conv3x3(group_x)# 对重构后的张量应用3x3卷积x11=self.softmax(self.agp(x1).reshape(b*self.groups,-1,1).permute(0,2,1))# 对 x1 进行自适应平均池化并应用Softmaxx12=x2.reshape(b*self.groups,c//self.groups,-1)# 重构 x2 的形状为 (b*groups, c//groups, h*w)x21=self.softmax(self.agp(x2).reshape(b*self.groups,-1,1).permute(0,2,1))# 对 x2 进行自适应平均池化并应用Softmaxx22=x1.reshape(b*self.groups,c//self.groups,-1)# 重构 x1 的形状为 (b*groups, c//groups, h*w)weights=(torch.matmul(x11,x12)+torch.matmul(x21,x22)).reshape(b*self.groups,1,h,w)# 计算权重,并重构为 (b*groups, 1, h, w)return(group_x*weights.sigmoid()).reshape(b,c,h,w)# 将权重应用于原始张量,并重构为原始输入形状

引入代码

在根目录下的ultralytics/nn/目录,新建一个attention目录,然后新建一个以EMA_attention为文件名的py文件, 把代码拷贝进去。

importtorchfromtorchimportnnclassEMA(nn.Module):def__init__(self,channels,c2=None,factor=32):super(EMA,self).__init__()self.groups=factorassertchannels//self.groups>0self.softmax=nn.Softmax(-1)self.agp=nn.AdaptiveAvgPool2d((1,1))self.pool_h=nn.AdaptiveAvgPool2d((None,1))self.pool_w=nn.AdaptiveAvgPool2d((1,None))self.gn=nn.GroupNorm(channels//self.groups,channels//self.groups)self.conv1x1=nn.Conv2d(channels//self.groups,channels//self.groups,kernel_size=1,stride=1,padding=0)self.conv3x3=nn.Conv2d(channels//self.groups,channels//self.groups,kernel_size=3,stride=1,padding=1)defforward(self,x):b,c,h,w=x.size()group_x=x.reshape(b*self.groups,-1,h,w)# b*g,c//g,h,wx_h=self.pool_h(group_x)x_w=self.pool_w(group_x).permute(0,1,3,2)hw=self.conv1x1(torch.cat([x_h,x_w],dim=2))x_h,x_w=torch.split(hw,[h,w],dim=2)x1=self.gn(group_x*x_h.sigmoid()*x_w.permute(0,1,3,2).sigmoid())x2=self.conv3x3(group_x)x11=self.softmax(self.agp(x1).reshape(b*self.groups,-1,1).permute(0,2,1))x12=x2.reshape(b*self.groups,c//self.groups,-1)# b*g, c//g, hwx21=self.softmax(self.agp(x2).reshape(b*self.groups,-1,1).permute(0,2,1))x22=x1.reshape(b*self.groups,c//self.groups,-1)# b*g, c//g, hwweights=(torch.matmul(x11,x12)+torch.matmul(x21,x22)).reshape(b*self.groups,1,h,w)return(group_x*weights.sigmoid()).reshape(b,c,h,w)

tasks注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:

fromultralytics.nn.attention.EMA_attentionimportEMA

步骤2

修改def parse_model(d, ch, verbose=True):

elifmin{EMA}:args=[ch[f],*args]

配置yolov8-EMA.yaml

‘ultralytics/cfg/models/v8/yolov8-EMA.yaml’

# Ultralytics YOLO 🚀, GPL-3.0 license# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parametersnc:2# number of classesscales:# model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n:[0.33,0.25,1024]# YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss:[0.33,0.50,1024]# YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm:[0.67,0.75,768]# YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl:[1.00,1.00,512]# YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx:[1.00,1.25,512]# YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,3,C2f,[128,True]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,6,C2f,[256,True]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,6,C2f,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,3,C2f,[1024,True]]-[-1,1,SPPF,[1024,5]]# 9# YOLOv8.0n headhead:-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,3,C2f,[512]]# 12-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,3,C2f,[256]]# 15 (P3/8-small)-[-1,1,EMA,[256]]# 16 (P5/32-large)-[-1,1,Conv,[256,3,2]]-[[-1,12],1,Concat,[1]]# cat head P4-[-1,3,C2f,[512]]# 19 (P4/16-medium)-[-1,1,EMA,[512]]# 20 (P5/32-large)-[-1,1,Conv,[512,3,2]]-[[-1,9],1,Concat,[1]]# cat head P5-[-1,3,C2f,[1024]]# 23 (P5/32-large)-[-1,1,EMA,[1024]]# 24 (P5/32-large)-[[16,20,24],1,Detect,[nc]]# Detect(P3, P4, P5)

实验

脚本

importosfromultralyticsimportYOLO yaml='ultralytics/cfg/models/v8/yolov8-EMA.yaml'model=YOLO(yaml)model.info()if__name__=="__main__":results=model.train(data='coco128.yaml',name='EMA',epochs=10,workers=8,batch=1)

结果

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 0:21:22

直线导轨限速受哪些因素影响?负载预压环境关联分析

“直线导轨的限速不是一个固定值,它受多种因素影响——这是很多客户容易忽略的点。作为台湾HIWIN集团正式授权专属经销商,深圳市海威机电有限公司今天就来分析影响直线导轨限速的三个核心因素:负载、预压、环境,帮你找到最适合的限…

作者头像 李华
网站建设 2026/4/18 6:30:37

Javaweb房产销售管理系统

可s我领取源码!JavaWeb 房产销售管理系统是一款专门为房地产销售企业设计的综合性管理平台,旨在提升房产销售流程的效率与透明度,实现房源、客户、销售团队等多方面的高效管理。借助 Java Web 技术,该系统能够为企业提供稳定、可靠且功能丰富…

作者头像 李华
网站建设 2026/4/18 6:31:07

8个AI论文工具,专科生轻松搞定毕业写作!

8个AI论文工具,专科生轻松搞定毕业写作! AI 工具让论文写作不再难 对于专科生来说,毕业论文可能是大学生活中最令人头疼的任务之一。从选题到开题,再到撰写和降重,每一步都充满了挑战。而随着 AI 技术的不断发展&#…

作者头像 李华
网站建设 2026/4/18 5:32:52

蓝桥杯 嵌入式 客观题 [1000道]第二期 持续更新中

1. 在蓝桥杯嵌入式竞赛常用的CT117E-M4开发板上,为了控制LED灯(LD1~LD8),使用了74HC573锁存器配合74LS138译码器进行片选。若要选通控制LED的锁存器(通常连接在Y4),则74LS138的输入端 A2, A1, A…

作者头像 李华
网站建设 2026/4/18 8:39:12

【MongoDB实战】7.3 批量操作优化:BulkWrite

文章目录 7.3 批量操作优化:BulkWrite 前置准备 1. 环境要求 2. 基础连接代码 7.3.1 循环单条操作vs批量操作:性能差异对比 核心差异 实战性能对比(测试10000条插入) 典型输出结果(参考) 差异原因分析 7.3.2 BulkWrite实战:批量插入、更新、删除组合操作 核心语法 实战:…

作者头像 李华