news 2026/4/18 16:17:15

卷积神经网络深度探索:NiN网络设计与实践

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
卷积神经网络深度探索:NiN网络设计与实践

网络中的网络(NiN)

学习目标

通过本课程的学习,学员将理解NiN网络的设计理念,同时学习NiN网络如何通过引入1×1卷积层和全局平均汇聚层来增强网络的表达能力和减少过拟合。

相关知识点

  • NiN模型介绍以及训练

学习内容

LeNet、AlexNet和VGG都有一个共同的设计模式:通过一系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进行处理。
AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。
网络中的网络NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机(Lin.Chen.Yan.2013)

1 NiN模型介绍以及训练

1.1 NiN块

回想一下,卷积层的输入和输出由四维张量组成,张量的每个轴分别对应样本、通道、高度和宽度。
另外,全连接层的输入和输出通常是分别对应于样本和特征的二维张量。
NiN的想法是在每个像素位置(针对每个高度和宽度)应用一个全连接层。
如果我们将权重连接到每个空间位置,我们可以将其视为1×11\times 11×1卷积层,或作为在每个像素位置上独立作用的全连接层。
从另一个角度看,即将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。

图1说明了VGG和NiN及它们的块之间主要架构差异。
NiN块以一个普通卷积层开始,后面是两个1×11 \times 11×1的卷积层。这两个1×11 \times 11×1卷积层充当带有ReLU激活函数的逐像素全连接层。
第一层的卷积窗口形状通常由用户设置。
随后的卷积窗口形状固定为1×11 \times 11×1

图1 对比 VGG 和 NiN 及它们的块之间主要架构差异

%pip install d2l
importtorchimporttorchvisionfromtorch.utilsimportdatafromtorchvisionimporttransformsimporttorch_npufromtorch_npu.contribimporttransfer_to_npufromtorchimportnnfromd2limporttorchasd2ldefnin_block(in_channels,out_channels,kernel_size,strides,padding):returnnn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU())
1.2 NiN模型

最初的NiN网络是在AlexNet后不久提出的,显然从中得到了一些启示。
NiN使用窗口形状为11×1111\times 1111×115×55\times 55×53×33\times 33×3的卷积层,输出通道数量与AlexNet中的相同。
每个NiN块后有一个最大汇聚层,汇聚窗口形状为3×33\times 33×3,步幅为2。

NiN和AlexNet之间的一个显著区别是NiN完全取消了全连接层。
相反,NiN使用一个NiN块,其输出通道数等于标签类别的数量。最后放一个全局平均汇聚层(global average pooling layer),生成一个对数几率 (logits)。NiN设计的一个优点是,它显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。

net=nn.Sequential(nin_block(1,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),# 将四维的输出转成二维的输出,其形状为(批量大小,10)nn.Flatten())

我们创建一个数据样本来查看每个块的输出形状。

X=torch.rand(size=(1,1,224,224))forlayerinnet:X=layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)

out:

Sequential output shape: torch.Size([1, 96, 54, 54]) MaxPool2d output shape: torch.Size([1, 96, 26, 26]) Sequential output shape: torch.Size([1, 256, 26, 26]) MaxPool2d output shape: torch.Size([1, 256, 12, 12]) Sequential output shape: torch.Size([1, 384, 12, 12]) MaxPool2d output shape: torch.Size([1, 384, 5, 5]) Dropout output shape: torch.Size([1, 384, 5, 5]) Sequential output shape: torch.Size([1, 10, 5, 5]) AdaptiveAvgPool2d output shape: torch.Size([1, 10, 1, 1]) Flatten output shape: torch.Size([1, 10])
1.3 获取数据集
!wget https://model-community-picture.obs.cn-north-4.myhuaweicloud.com/ascend-zone/notebook_datasets/0edd96d814ee11f09ef9fa163edcddae/FashionMNIST.zip
!unzip FashionMNIST.zip

现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

defload_data_fashion_mnist(batch_size,resize=None):#@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans=[transforms.ToTensor()]ifresize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=trans,download=False)mnist_test=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=trans,download=False)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=4),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=4))
1.4 训练模型

和以前一样,我们使用Fashion-MNIST来训练模型。训练NiN与训练AlexNet、VGG时相似。

lr,num_epochs,batch_size=0.1,5,128train_iter,test_iter=load_data_fashion_mnist(batch_size,resize=224)d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 8:19:50

【MCP IP冲突检测神器推荐】:5款高效工具揭秘,告别网络瘫痪危机

第一章:MCP IP冲突检测工具概述在现代数据中心与云计算环境中,虚拟机和容器的大规模部署使得IP地址管理变得愈发复杂。MCP(Multi-Cloud Platform)IP冲突检测工具是一款专为跨云环境设计的网络诊断组件,用于实时发现并报…

作者头像 李华
网站建设 2026/4/17 8:42:00

Spring AOP实现原理及几种应用方式详解

在Spring框架中,AOP(面向切面编程)是实现关注点分离、增强代码模块化的重要工具。它允许开发者将横切关注点(如日志、事务管理)从核心业务逻辑中抽离,通过声明式或编程式的方式织入到程序执行流程中。理解其…

作者头像 李华
网站建设 2026/4/18 3:38:08

数据中心如果有几十甚至几百 T 数据,如何实现数据安全和数据备份?

说句实在话,很多人第一次真正面对几十 T、上百 T 数据的时候,都会有一个错觉: “我们不是早就做了 RAID、做了备份吗?还能出什么事?” 然后,事故真的发生一次,你就再也不敢这么想了。 我见过的数据中心事故里,真正致命的,从来不是硬盘坏了,而是: 误删 脚本写错 勒…

作者头像 李华
网站建设 2026/4/18 3:30:51

C语言转中文编程:编译器如何实现关键字转换?

从C语言转向中文编程语言,本质上是将一种成熟的、以英文关键字为基础的编程体系,转化为更贴近中文思维习惯的编程环境。这不仅仅是关键字的简单翻译,更涉及到编译器设计、语法解析、社区生态等一系列工程与理念的挑战。对于习惯了C语言严谨性…

作者头像 李华
网站建设 2026/4/18 7:59:04

STM32驱动开发中Keil工程搭建核心要点

从零搭建一个可靠的STM32开发环境:Keil工程实战全解析你有没有过这样的经历?新项目刚开,信心满满地打开Keil,新建工程、添加文件、写好main函数,一编译——报错;好不容易编译通过了,下载进去单片…

作者头像 李华