news 2026/6/10 11:11:32

第P2周:CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P2周:CIFAR10彩色图片识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 数据可视化

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、个人总结

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import matplotlib.pyplot as plt import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True) test_ds = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), download=True) batch_size = 32 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size) imgs, labels = next(iter(train_dl)) imgs.shape

3. 数据可视化

import numpy as np plt.figure(figsize=(20, 5)) for i, imgs in enumerate(imgs[:20]): npimg = imgs.numpy().transpose((1, 2, 0)) plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F num_classes = 10 class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, kernel_size=3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, start_dim=1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary # 将模型转移到GPU中(我们模型运行均在GPU中进行) model = Model().to(device) summary(model)

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-2 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 10 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、个人总结

逐渐熟悉CNN模型构建过程,并逐步理解其原理。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/1 3:56:09

LobeChat未成年人保护机制

LobeChat 未成年人保护机制:构建安全可控的 AI 对话环境 在 AI 聊天应用日益普及的今天,孩子们只需轻点屏幕就能与“无所不知”的智能助手对话。这看似便利的背后,却潜藏着真实的风险——一个关于暴力、自残或成人话题的提问,可能…

作者头像 李华
网站建设 2026/6/9 19:43:25

新风口!NHANES肥胖新指标--代谢表型肥胖可一键提取

郑老师的NHANES Online平台,可零代码一键提取和分析数据!目前在持续快速更新指标中!(ps:感兴趣的指标可以和我们说一下,为您快马加鞭安排上!)平台目前可直接分析的所有指标如下&…

作者头像 李华
网站建设 2026/6/2 9:36:40

jQuery EasyUI 布局 - 添加自动播放标签页(Tabs)

jQuery EasyUI 布局 - 添加自动播放标签页(Tabs) jQuery EasyUI 的 tabs 组件本身不内置自动播放(autoplay)功能,但可以通过简单的 JavaScript 代码实现自动切换标签页(autoplay tabs)&#xf…

作者头像 李华
网站建设 2026/6/9 18:08:13

EmotiVoice语音合成系统日志记录与监控方案设计

EmotiVoice语音合成系统日志记录与监控方案设计 在如今的AI应用浪潮中,文本转语音(TTS)早已不再是简单的“机器朗读”,而是朝着情感化、个性化、拟人化的方向快速演进。EmotiVoice作为一款开源的高表现力语音合成引擎,…

作者头像 李华
网站建设 2026/6/10 0:18:22

29、虚拟化主机与应用:KVM与Docker的实践指南

虚拟化主机与应用:KVM与Docker的实践指南 在当今的IT领域,虚拟化技术已经成为了提高资源利用率、简化管理和降低成本的重要手段。本文将深入探讨KVM虚拟机网络桥接和Docker容器的创建、运行与管理,为你提供详细的操作指南和技术解析。 1. KVM虚拟机网络桥接 KVM(Kernel-…

作者头像 李华