news 2026/4/18 12:45:39

[深度学习]Vision Transformer

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
[深度学习]Vision Transformer

Pytorch实现Vision Transformer

importtorchimporttorch.nnasnnclassPatchEmbedding(nn.Module):def__init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768):super().__init__()self.img_size=img_size self.patch_size=patch_size self.n_patches=(img_size//patch_size)**2# 使用卷积层实现patch分割和嵌入self.proj=nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)defforward(self,x):# 输入x形状: [batch_size, in_channels, img_size, img_size]# 输出形状: [batch_size, n_patches, embed_dim]x=self.proj(x)# [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]x=x.flatten(2)# [batch_size, embed_dim, n_patches]x=x.transpose(1,2)# [batch_size, n_patches, embed_dim]returnxclassPositionEmbedding(nn.Module):def__init__(self,n_patches,embed_dim,dropout=0.1):super().__init__()self.pos_embed=nn.Parameter(torch.zeros(1,n_patches+1,embed_dim))# +1 for class tokenself.dropout=nn.Dropout(dropout)defforward(self,x):# x形状: [batch_size, n_patches+1, embed_dim]x=x+self.pos_embed# 添加位置编码x=self.dropout(x)returnxclassMultiHeadAttention(nn.Module):def__init__(self,embed_dim,num_heads,dropout=0.1):super().__init__()self.embed_dim=embed_dim self.num_heads=num_heads self.head_dim=embed_dim//num_headsassertself.head_dim*num_heads==embed_dim,"Embedding dimension must be divisible by number of heads"self.qkv=nn.Linear(embed_dim,embed_dim*3)# 同时计算Q,K,Vself.attn_dropout=nn.Dropout(dropout)self.proj=nn.Linear(embed_dim,embed_dim)self.proj_dropout=nn.Dropout(dropout)self.scale=self.head_dim**-0.5defforward(self,x):batch_size,n_patches,embed_dim=x.shape# 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]qkv=self.qkv(x).reshape(batch_size,n_patches,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)q,k,v=qkv[0],qkv[1],qkv[2]# 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]attn=(q @ k.transpose(-2,-1))*self.scale attn=attn.softmax(dim=-1)attn=self.attn_dropout(attn)# 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]out=attn @ v out=out.transpose(1,2).reshape(batch_size,n_patches,embed_dim)# 线性投影和dropoutout=self.proj(out)out=self.proj_dropout(out)returnoutclassMLP(nn.Module):def__init__(self,in_features,hidden_features,out_features,dropout=0.1):super().__init__()self.fc1=nn.Linear(in_features,hidden_features)self.act=nn.GELU()self.fc2=nn.Linear(hidden_features,out_features)self.dropout=nn.Dropout(dropout)defforward(self,x):x=self.fc1(x)x=self.act(x)x=self.dropout(x)x=self.fc2(x)x=self.dropout(x)returnxclassTransformerBlock(nn.Module):def__init__(self,embed_dim,num_heads,mlp_ratio=4,dropout=0.1):super().__init__()self.norm1=nn.LayerNorm(embed_dim)self.attn=MultiHeadAttention(embed_dim,num_heads,dropout)self.norm2=nn.LayerNorm(embed_dim)self.mlp=MLP(in_features=embed_dim,hidden_features=embed_dim*mlp_ratio,out_features=embed_dim,dropout=dropout)defforward(self,x):# 残差连接和层归一化x=x+self.attn(self.norm1(x))x=x+self.mlp(self.norm2(x))returnxclassVisionTransformer(nn.Module):def__init__(self,img_size=224,patch_size=16,in_channels=3,n_classes=1000,embed_dim=768,depth=12,num_heads=12,mlp_ratio=4,dropout=0.1):super().__init__()self.patch_embed=PatchEmbedding(img_size,patch_size,in_channels,embed_dim)n_patches=self.patch_embed.n_patches# 分类token和位置编码self.cls_token=nn.Parameter(torch.zeros(1,1,embed_dim))self.pos_embed=PositionEmbedding(n_patches,embed_dim,dropout)# Transformer编码器self.blocks=nn.Sequential(*[TransformerBlock(embed_dim,num_heads,mlp_ratio,dropout)for_inrange(depth)])# 分类头self.norm=nn.LayerNorm(embed_dim)self.head=nn.Linear(embed_dim,n_classes)# 初始化权重nn.init.trunc_normal_(self.cls_token,std=0.02)defforward(self,x):batch_size=x.shape[0]# 生成patch嵌入x=self.patch_embed(x)# [batch_size, n_patches, embed_dim]# 添加class tokencls_token=self.cls_token.expand(batch_size,-1,-1)x=torch.cat([cls_token,x],dim=1)# [batch_size, n_patches+1, embed_dim]# 添加位置编码x=self.pos_embed(x)# 通过Transformer编码器x=self.blocks(x)# 分类x=self.norm(x)cls_token_final=x[:,0]# 只取class token对应的输出x=self.head(cls_token_final)returnxif__name__=='__main__':x=torch.rand(1,3,224,224)model=VisionTransformer(img_size=224,patch_size=16,)y=model(x)print('y.shape = ',y.shape)print(y)

参考资料

  1. 博客
  2. B站教程
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 15:58:41

还要多久?NASA卫星从太空俯瞰,那条通往“正义”的道路

当NASA的卫星从数百公里的高空俯瞰地球,它们通常在记录冰川的消融或城市的扩张。但这一次,Landsat 8 卫星的镜头聚焦在了阿拉巴马州的一条街道上。这里,曾见证了一场改变人类文明进程的漫长行军。 来自太空的“历史快照”2025年9月&#xff0…

作者头像 李华
网站建设 2026/4/18 9:44:11

设计模式 -详解

1.单例模式 单例模式是指在整个应用中一个类的对象只允许出现一个(类的对象最多 只允许创建一次); 我们在创建一个类的对象时,调用的是类的构造器,所以在单例中类的构 造器只允许调用一次 核心:构造方法私有化,不允许…

作者头像 李华
网站建设 2026/4/18 7:03:42

您的APP还在“隐身”吗?2026年ASO优化高级实战指南

应用商店优化 (ASO)是一个持续的过程,旨在通过优化元数据(标题、关键词)、创意素材(应用截图、视频)和性能指标(应用评分、应用评论)来提升应用在Apple和Google Play等应用商店中的曝光度和转化…

作者头像 李华
网站建设 2026/4/18 1:25:39

【山海鲸实战案例】通过二维组件控制三维场景昼夜变化

在项目制作过程中,我们可能会需要手动控制三维场景的昼夜切换,此时通过按钮组件的交互设置就可以非常简单地达到目的,下面我们就来看一下具体该如何进行设置。 首先,创建一个三维场景。 添加两个“按钮”组件,分别命名…

作者头像 李华
网站建设 2026/4/18 9:44:28

原子层加工技术推动碳化硅量子光子电路发展

原子层加工技术助力碳化硅量子光子电路蓬勃发展 来自马克斯普朗克光科学研究所(Max Planck Institute for the Science of Light)与弗劳恩霍夫集成系统与元器件技术研究所(Fraunhofer Institute for Integrated Systems and Device Technolo…

作者头像 李华
网站建设 2026/4/18 7:55:17

深圳跨境电商中的“亚马逊精品模式“详解

深圳跨境电商中的"亚马逊精品模式"详解 一、核心定义 亚马逊精品模式是跨境电商中一种"少而精"的运营策略,指卖家专注于少数高潜力产品(通常成熟期仅需10-20款),通过深度选品、精细化运营和供应链优化&#x…

作者头像 李华