news 2026/4/18 11:47:09

第P3周:Pytorch实现天气识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P3周:Pytorch实现天气识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客

  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 显示图片

4. 划分数据集

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、 个人总结

过拟合的确认方法

解决方案

1. 正则化措施

2. 数据增强优化

3. 批归一化(BN)应用

4. 早停策略

5. 动态学习率调整

6. 优化器升级

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision from torchvision import transforms, datasets import os,PIL,pathlib,random device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

data_dir = './data/' data_dir = pathlib.Path(data_dir) data_paths = list(data_dir.glob('*')) classeNames = [str(path).split("\\")[1] for path in data_paths] classeNames
  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames
  • 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3. 显示图片

import matplotlib.pyplot as plt from PIL import Image # 指定图像文件夹路径 image_folder = './data/cloudy/' # 获取文件夹中的所有图像文件 image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))] # 创建Matplotlib图像 fig, axes = plt.subplots(3, 8, figsize=(16, 6)) # 使用列表推导式加载和显示图像 for ax, img_file in zip(axes.flat, image_files): img_path = os.path.join(image_folder, img_file) img = Image.open(img_path) ax.imshow(img) ax.axis('off') # 显示图像 plt.tight_layout() plt.show()

total_datadir = './data/' # 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863 train_transforms = transforms.Compose([ transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸 transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间 transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。 ]) total_data = datasets.ImageFolder(total_datadir,transform=train_transforms) total_data

4. 划分数据集

train_size = int(0.8 * len(total_data)) test_size = len(total_data) - train_size train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size]) train_dataset, test_dataset train_size,test_size batch_size = 32 train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) for X, y in test_dl: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break

二、构建简单的CNN网络

import torch.nn.functional as F class Network_bn(nn.Module): def __init__(self): super(Network_bn, self).__init__() """ nn.Conv2d()函数: 第一个参数(in_channels)是输入的channel数量 第二个参数(out_channels)是输出的channel数量 第三个参数(kernel_size)是卷积核大小 第四个参数(stride)是步长,默认为1 第五个参数(padding)是填充大小,默认为0 """ self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(12) self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn2 = nn.BatchNorm2d(12) self.pool1 = nn.MaxPool2d(2,2) self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn4 = nn.BatchNorm2d(24) self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn5 = nn.BatchNorm2d(24) self.pool2 = nn.MaxPool2d(2,2) self.fc1 = nn.Linear(24*50*50, len(classeNames)) def forward(self, x): x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) x = self.pool1(x) x = F.relu(self.bn4(self.conv4(x))) x = F.relu(self.bn5(self.conv5(x))) x = self.pool2(x) x = x.view(-1, 24*50*50) x = self.fc1(x) return x device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) model = Network_bn().to(device) model

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-4 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 20 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、 个人总结

过拟合的确认方法

当模型出现疑似过拟合时,可通过以下方法进一步确认:

  1. 增加训练轮次:持续训练时若验证集准确率下降而损失上升,则确认过拟合
  2. 降低模型复杂度:如减少全连接层或通道数后验证指标提升,说明原模型过于复杂
  3. 增强数据多样性:添加数据增强后若验证指标改善,表明原训练数据不足导致模型死记硬背

解决方案

1. 正则化措施

  • Dropout:在全连接层前添加nn.Dropout(0.5),通过随机丢弃神经元强制学习鲁棒特征
  • L2正则化:优化器中设置weight_decay=1e-4,抑制过大权重,防止噪声拟合

2. 数据增强优化

  • 方法:使用torchvision.transforms实现随机裁剪、翻转、颜色变换等
  • 作用:提升数据多样性,促使模型学习通用特征而非样本细节,增强泛化能力

3. 批归一化(BN)应用

  • 原理:对卷积层输出进行归一化处理(均值≈0,方差≈1)后缩放平移
  • 优势
    • 稳定层间输入分布,加速收敛并缓解梯度问题
    • 具有正则化效果,类似"mini-batch级数据增强"

4. 早停策略

  • 实施:持续监控验证集准确率,当性能不再提升(超过patience轮次)时终止训练
  • 价值
    • 防止过度拟合训练噪声
    • 节省计算资源,自动获取最佳泛化模型

5. 动态学习率调整

  • 机制:当验证性能停滞时,按因子(如0.5)降低学习率
  • 效益
    • 实现更精细的参数优化
    • 平缓的参数更新可降低过拟合风险

6. 优化器升级

  • 改进方案:将SGD替换为Adam优化器并配合权重衰减
  • 原因:当前SGD学习率(1e-4)偏低导致收敛缓慢,且缺乏动量和正则化支持
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 2:04:04

回溯+位运算|前缀和优化背包

汉字写一遍代码思路 可以提升和人沟通 和加深理解所以不可以只写完代码就去去丸了 还是要复习lc756class Solution { public:bool pyramidTransition(string bottom, vector<string>& allowed) {vector<int> groups[7][7];for (auto& s : allowed) {// A~F…

作者头像 李华
网站建设 2026/4/18 2:02:26

PostgreSQL 时间/日期处理指南

PostgreSQL 时间/日期处理指南 引言 PostgreSQL 是一款功能强大的开源关系型数据库系统,它提供了丰富的数据类型和功能,其中包括对时间/日期数据的支持。本文将详细介绍 PostgreSQL 中时间/日期类型的使用方法,包括数据类型、常用函数、操作和注意事项。 PostgreSQL 时间…

作者头像 李华
网站建设 2026/4/18 2:04:46

Doris资源组管理:精细化资源分配策略

Doris资源组管理:精细化资源分配的"食堂排队秘诀" 关键词:Doris资源组、精细化资源分配、查询优化、资源隔离、队列调度、Cgroup、多租户管理 摘要:当Doris作为大规模数据查询的"餐厅"时,如何让"食客"(查询)快速吃到"饭"(结果)…

作者头像 李华
网站建设 2026/4/17 8:38:54

毕业生都在用的十大降ai工具,建议收藏

家人们&#xff0c;现在学校查得是真严&#xff0c;不仅重复率&#xff0c;还得降ai率&#xff0c;学校规定必须得20%以下... 折腾了半个月&#xff0c;终于把市面上各类方法试了个遍&#xff0c;坑踩了不少&#xff0c;智商税也交了。今天这就把这份十大降AI工具合集掏心窝子…

作者头像 李华
网站建设 2026/4/18 1:59:16

fwrite与fflush作用

简单说&#xff1a; fwrite 负责“写数据”&#xff0c; fflush 负责“把缓冲里的内容真的推到文件/设备”。一、 fwrite 做什么&#xff1f;fwrite 是标准 C 里的带缓冲的文件写入函数&#xff0c;原型&#xff1a;csize_t fwrite(const void *ptr, size_t size, size_t nme…

作者头像 李华
网站建设 2026/4/18 1:57:14

《告别跨端运算偏差:游戏确定浮点数学库的核心搭建指南》

早期涉足游戏开发时,曾执着于浮点精度的极致提升,认为更高的精度就能消除所有差异,直到在一款多人协作游戏的测试中,见证过同一技能在PC端与移动端的伤害结算偏差、主机玩家与手机玩家看到的角色跳跃轨迹分歧—明明是相同的触发条件,却出现技能命中判定失效、物理道具飞行…

作者头像 李华