前言
本文介绍了Token Statistics Self-Attention(TSSA)机制,并将其集成到YOLO26中。传统自注意力计算复杂度高,TSSA进行了范式转变,基于token统计特征实现高效注意力交互。它通过“算法展开”推导得出,以“最大编码率降低”为目标,实现特征学习。TSSA包含动态分组和低秩投影优化两步创新,具备线性复杂度。我们将TSSA代码集成到YOLO26的C2PSA模块中。实验表明,改进后的YOLO26在目标检测任务中表现良好,验证了TSSA机制的有效性。
文章目录: YOLO26改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总
专栏链接: YOLO26改进专栏
文章目录
- 前言
- 介绍
- 摘要
- 文章链接
- 基本原理
- 1. 从“逐对对比”到“统计聚合”的范式转变
- 2. 基于“白盒设计”的目标导向优化
- 3. 数据驱动的低秩投影与动态分组
- YOLO26引入代码
- 注册
- 步骤1:
- 步骤2
- 配置yolo26-C2PSA_TSSA.yaml
- 实验
- 脚本
- 结果
介绍
摘要
注意力算子可以说是 Transformer 架构的关键特征,该架构在多种任务中都表现出了最先进的性能。然而,Transformer 的注意力算子通常会带来巨大的计算负担,其计算复杂度随 Token 数量呈二次方增长。在这项工作中,我们提出了一种新型的 Transformer 注意力算子,其计算复杂度随 Token 数量呈线性增长。我们将之前的研究成果进行了扩展,之前的研究表明,通过“白盒”架构设计可以自然地构建出 Transformer 风格的架构,即网络的每一层都被设计为实现最大编码率降低目标(M C R 2 MCR^{2}MCR2)的一个增量优化步骤。具体来说,我们推导了M C R 2 MCR^{2}MCR2目标的一种新颖变分形式,并展示了基于该变分目标进行展开梯度下降所得到的架构,导出了一种新的注意力模块,称为Token 统计自注意力(Token Statistics Self-Attention,TSSA)。TSSA 具有线性的计算和内存复杂度,并且与计算 Token 之间成对相似度的典型注意力架构截然不同。在视觉、语言和长序列任务上的实验表明,只需简单地用 TSSA 替换标准自注意力(我们将这种架构称为Token 统计 Transformer,即 TOST),就能获得与传统 Transformer 相当的性能,同时计算效率更高且更具可解释性。我们的结果还在一定程度上质疑了“成对相似度风格的注意力机制是 Transformer 架构成功的关键”这一传统观念。代码将在 https://github.com/RobinWu218/ToST 开源。
文章链接
论文地址:论文地址
代码地址:代码地址
基本原理
TSSA(Token Statistics Self-Attention)的核心创新是彻底抛弃传统自注意力的“成对相似度计算”,转而基于token的统计特征实现高效注意力交互 :
1. 从“逐对对比”到“统计聚合”的范式转变
传统自注意力需要计算所有token两两之间的相似度(如缩放点积),导致复杂度随token数量呈平方增长。TSSA跳出这一框架,认为注意力的本质是“基于数据关联的特征优化”,而这种关联无需逐对计算——只需捕捉token群体的统计规律(即“二阶矩”,可理解为token特征的分布集中程度),就能实现类似的特征聚合效果。
2. 基于“白盒设计”的目标导向优化
TSSA并非经验性设计,而是通过“算法展开”的白盒思路推导得出:以“最大编码率降低(MCR²)”为核心目标,先将该目标转化为更易计算的变分形式,再把优化过程拆分成网络的逐层操作。每一层的作用都是增量优化这个目标——让同一组内的token特征更集中(压缩),同时让所有token的整体特征更分散(扩展),最终实现 discriminative 特征学习。
3. 数据驱动的低秩投影与动态分组
TSSA的核心操作包含两步关键创新:
- 动态分组:通过计算token与不同子空间的匹配度,用软聚类(类似概率分配)将token分到K个组,无需人工定义分组规则,完全由数据自动决定。
- 低秩投影优化:对每个组,基于token特征的统计信息构建“重要性权重”,保留组内特征中“能量集中”(即多数token共同拥有)的方向,抑制冗余或噪声方向。这一过程不依赖任何成对相似度,仅通过矩阵投影和统计计算完成,天然具备线性复杂度。
YOLO26引入代码
在根目录下的ultralytics/nn/目录,新建一个C2PSA目录,然后新建一个以C2PSA_TSSA为文件名的py文件, 把代码拷贝进去。
importtorchimporttorch.nnasnnfromeinopsimportrearrangeclassAttentionTSSA(nn.Module):def__init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop=0.,proj_drop=0.):super().__init__()self.heads=num_heads self.dim=dim head_dim=dim//num_heads self.attend=nn.Softmax(dim=1)self.attn_drop=nn.Dropout(attn_drop)self.qkv=nn.Linear(dim,dim,bias=qkv_bias)self.temp=nn.Parameter(torch.ones(num_heads,1))self.to_out=nn.Sequential(nn.Linear(dim,dim),nn.Dropout(proj_drop))defforward(self,x):# x: (B, C, H, W) - standard attention interfaceB,C,H,W=x.shape N=H*W x_flat=x.view(B,C,N).permute(0,2,1)# (B, N, C)# Apply linear projection and reshape for multi-headw=self.qkv(x_flat)# (B, N, C)w=w.view(B,N,self.heads,C//self.heads).permute(0,2,1,3)# (B, h, N, d)w_normed=torch.nn.functional.normalize(w,dim=-2)w_sq=w_normed**2# Pi from Eq. 10 in the paperPi=self.attend(torch.sum(w_sq,dim=-1)*self.temp)# b * h * ndots=torch.matmul((Pi/(Pi.sum(dim=-1,keepdim=True)+1e-8)).unsqueeze(-2),w**2)attn=1./(1+dots)attn=self.attn_drop(attn)out=-torch.mul(w.mul(Pi.unsqueeze(-1)),attn)out=rearrange(out,'b h n d -> b n (h d)')out=self.to_out(out)# Reshape back to (B, C, H, W)out=out.permute(0,2,1).view(B,C,H,W)returnout@torch.jit.ignoredefno_weight_decay(self):return{'temp'}classAttention(nn.Module):def__init__(self,dim:int,num_heads:int=8,attn_ratio:float=0.5):super().__init__()self.num_heads=num_heads self.head_dim=dim//num_heads self.key_dim=int(self.head_dim*attn_ratio)self.scale=self.key_dim**-0.5nh_kd=self.key_dim*num_heads h=dim+nh_kd*2self.qkv=Conv(dim,h,1,act=False)self.proj=Conv(dim,dim,1,act=False)self.pe=Conv(dim,dim,3,1,g=dim,act=False)defforward(self,x:torch.Tensor)->torch.Tensor:B,C,H,W=x.shape N=H*W qkv=self.qkv(x)q,k,v=qkv.view(B,self.num_heads,self.key_dim*2+self.head_dim,N).split([self.key_dim,self.key_dim,self.head_dim],dim=2)attn=(q.transpose(-2,-1)@ k)*self.scale attn=attn.softmax(dim=-1)x=(v @ attn.transpose(-2,-1)).view(B,C,H,W)+self.pe(v.reshape(B,C,H,W))x=self.proj(x)returnxdefautopad(k,p=None,d=1):"""Pad to 'same' shape outputs."""ifd>1:k=d*(k-1)+1ifisinstance(k,int)else[d*(x-1)+1forxink]ifpisNone:p=k//2ifisinstance(k,int)else[x//2forxink]# auto-padreturnpclassConv(nn.Module):default_act=nn.SiLU()def__init__(self,c1,c2,k=1,s=1,p=None,g=1,d=1,act=True):super().__init__()self.conv=nn.Conv2d(c1,c2,k,s,autopad(k,p,d),groups=g,dilation=d,bias=False)self.bn=nn.BatchNorm2d(c2)self.act=(self.default_actifactisTrueelseactifisinstance(act,nn.Module)elsenn.Identity())defforward(self,x):c=self.conv(x)c=self.bn(c)c=self.act(c)returncclassPSABlock(nn.Module):def__init__(self,c:int,attn_ratio:float=0.5,num_heads:int=4,shortcut:bool=True)->None:super().__init__()self.attn=Attention(c,attn_ratio=attn_ratio,num_heads=num_heads)self.ffn=nn.Sequential(Conv(c,c*2,1),Conv(c*2,c,1,act=False))self.add=shortcutdefforward(self,x:torch.Tensor)->torch.Tensor:x=x+self.attn(x)ifself.addelseself.attn(x)x=x+self.ffn(x)ifself.addelseself.ffn(x)returnxclassC2PSA(nn.Module):def__init__(self,c1,c2,n=1,e=0.5):super().__init__()assertc1==c2 self.c=int(c1*e)self.cv1=Conv(c1,2*self.c,1,1)self.cv2=Conv(2*self.c,c1,1)self.m=nn.Sequential(*(PSABlock(self.c,attn_ratio=0.5,num_heads=self.c//64)for_inrange(n)))defforward(self,x):a,b=self.cv1(x).split((self.c,self.c),dim=1)b=self.m(b)returnself.cv2(torch.cat((a,b),1))classPSABlock_AttentionTSSA(PSABlock):def__init__(self,c:int,attn_ratio:float=0.5,num_heads:int=4,shortcut:bool=True)->None:super().__init__(c,attn_ratio,num_heads,shortcut)self.attn=AttentionTSSA(c)classC2PSA_TSSA(C2PSA):def__init__(self,c1:int,c2:int,n:int=1,e:float=0.5):super().__init__(c1,c2,n,e)self.m=nn.Sequential(*(PSABlock_AttentionTSSA(self.c,attn_ratio=0.5,num_heads=self.c//64)for_inrange(n)))注册
在ultralytics/nn/tasks.py中进行如下操作:
步骤1:
fromultralytics.nn.C2PSA.C2PSA_TSSAimportC2PSA_TSSA步骤2
修改def parse_model(d, ch, verbose=True):
C2PSA_TSSA配置yolo26-C2PSA_TSSA.yaml
ultralytics/cfg/models/26/yolo26-C2PSA_TSSA.yaml
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license# Ultralytics YOLO26 object detection model with P3/8 - P5/32 outputs# Model docs: https://docs.ultralytics.com/models/yolo26# Task docs: https://docs.ultralytics.com/tasks/detect# Parametersnc:80# number of classesend2end:True# whether to use end-to-end modereg_max:1# DFL binsscales:# model compound scaling constants, i.e. 'model=yolo26n.yaml' will call yolo26.yaml with scale 'n'# [depth, width, max_channels]n:[0.50,0.25,1024]# summary: 260 layers, 2,572,280 parameters, 2,572,280 gradients, 6.1 GFLOPss:[0.50,0.50,1024]# summary: 260 layers, 10,009,784 parameters, 10,009,784 gradients, 22.8 GFLOPsm:[0.50,1.00,512]# summary: 280 layers, 21,896,248 parameters, 21,896,248 gradients, 75.4 GFLOPsl:[1.00,1.00,512]# summary: 392 layers, 26,299,704 parameters, 26,299,704 gradients, 93.8 GFLOPsx:[1.00,1.50,512]# summary: 392 layers, 58,993,368 parameters, 58,993,368 gradients, 209.5 GFLOPs# YOLO26n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,2,C3k2,[256,False,0.25]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,2,C3k2,[512,False,0.25]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,2,C3k2,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,2,C3k2,[1024,True]]-[-1,1,SPPF,[1024,5,3,True]]# 9-[-1,2,C2PSA_TSSA,[1024]]# 10# YOLO26n headhead:-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,2,C3k2,[512,True]]# 13-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,2,C3k2,[256,True]]# 16 (P3/8-small)-[-1,1,Conv,[256,3,2]]-[[-1,13],1,Concat,[1]]# cat head P4-[-1,2,C3k2,[512,True]]# 19 (P4/16-medium)-[-1,1,Conv,[512,3,2]]-[[-1,10],1,Concat,[1]]# cat head P5-[-1,1,C3k2,[1024,True,0.5,True]]# 22 (P5/32-large)-[[16,19,22],1,Detect,[nc]]# Detect(P3, P4, P5)实验
脚本
importwarnings warnings.filterwarnings('ignore')fromultralyticsimportYOLOif__name__=='__main__':# 修改为自己的配置文件地址model=YOLO('./ultralytics/cfg/models/26/yolo26-C2PSA_TSSA.yaml')# 修改为自己的数据集地址model.train(data='./ultralytics/cfg/datasets/coco8.yaml',cache=False,imgsz=640,epochs=10,single_cls=False,# 是否是单类别检测batch=8,close_mosaic=10,workers=0,optimizer='MuSGD',# optimizer='SGD',amp=False,project='runs/train',name='yolo26-C2PSA_TSSA',)