import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import TensorDataset, DataLoader import gzip import pickle from pathlib import Path import numpy as np # 定义数据根目录路径 DATA_PATH = Path("./data") # 拼接MNIST数据子目录路径 PATH = DATA_PATH / "MNIST" # 定义MNIST数据文件名 FILENAME = "mnist.pkl.gz" # 以gzip压缩格式打开文件,读取二进制数据 # (PATH/FILENAME).as_posix() 将Path对象转为字符串路径,兼容不同操作系统 with gzip.open((PATH/FILENAME).as_posix(), 'rb') as f: # 加载pickle序列化数据,指定编码为latin-1以兼容旧版本数据 # 数据结构:(训练集, 验证集, 测试集),每个集合包含(特征, 标签) ((x_train, y_train), (x_valid, y_valid),(x_test, y_test)) = pickle.load(f,encoding='latin-1') # 将numpy数组转换为PyTorch张量,方便后续计算 x_train,y_train,x_valid,y_valid=map(torch.Tensor,(x_train,y_train,x_valid,y_valid)) # 定义损失函数为交叉熵损失(适用于多分类任务) # 交叉熵损失会自动将输出经过softmax并计算负对数似然损失 loss_func=F.cross_entropy # 定义批次大小:每次训练传入模型的样本数量 bs=64 # batch-size 指定一次训练的样本个数 """从训练集中取前64个样本做测试(后续实际使用DataLoader批量加载)""" xb=x_train[0:bs] # 取前64个训练集特征数据 yb=y_train[0:bs] # 取前64个训练集标签数据 """ 使用TensorDataset和DataLoader来简化数据加载流程 TensorDataset:将特征和标签打包成数据集 DataLoader:按批次加载数据,支持打乱、多线程等功能 """ # 将训练集特征和标签打包成TensorDataset对象 train_ds = TensorDataset(x_train,y_train) # 将验证集特征和标签打包成TensorDataset对象 valid_ds=TensorDataset(x_valid,y_valid) # 数据加载器封装函数:统一创建训练/验证数据加载器 # 功能:简化数据加载器的创建流程,确保训练/验证集加载器参数统一 # 参数: # train_ds:训练集TensorDataset对象 # valid_ds:验证集TensorDataset对象 # bs:训练集批次大小 # 返回: # 元组:(训练集DataLoader, 验证集DataLoader) def get_data(train_ds_1, valid_ds_1, bs_1): return ( # 训练集加载器:打乱顺序 DataLoader(train_ds_1,batch_size=bs_1,shuffle=True), # 验证集加载器:批次翻倍,不打乱(验证无需打乱) DataLoader(valid_ds_1,batch_size=bs_1*2), ) # 定义MNIST手写数字识别的神经网络模型类 # 继承nn.Module(PyTorch所有模型的基类) class MnistNN(nn.Module): # 构造函数:初始化模型层结构 def __init__(self): # 调用父类构造函数 super().__init__() # 第一个全连接隐藏层:输入维度784(28*28像素),输出维度128 self.hidden1=nn.Linear(784,128) # 第二个全连接隐藏层:输入维度128,输出维度256 self.hidden2=nn.Linear(128,256) # 注释掉的第三个隐藏层(可扩展) # self.hidden3=nn.Linear(256,512) # 输出层:输入维度256,输出维度10(对应0-9十个数字分类) self.out = nn.Linear(256,10) # Dropout层:随机丢弃50%的神经元,防止过拟合 self.dropout = nn.Dropout(0.5) # 前向传播函数:定义数据在模型中的流动路径 # 参数x:输入的特征张量,形状为[batch_size, 784] # 返回:模型输出的未经过softmax的原始得分(logits),形状为[batch_size, 10] def forward(self, x): # 第一层:线性变换 + ReLU激活函数 x=F.relu(self.hidden1(x)) # 对第一层输出进行Dropout正则化 x=self.dropout(x) # 第二层:线性变换 + ReLU激活函数 x=F.relu(self.hidden2(x)) # 对第二层输出进行Dropout正则化 x=self.dropout(x) # 注释掉的第三层计算(可扩展) # x=F.relu(self.hidden3(x)) # 输出层:线性变换(无激活函数,交叉熵损失会自动处理softmax) x=self.out(x) return x # 模型和优化器创建函数:封装模型初始化和优化器配置 # 功能:创建模型实例并配置Adam优化器 # 参数:无 # 返回: # 元组:(模型实例, Adam优化器实例) def get_model(): model_1 = MnistNN() # 创建模型实例 # 返回模型和Adam优化器: # model.parameters():需要优化的模型参数 # lr=0.001:学习率 return model_1,optim.Adam(model_1.parameters(),lr=0.001) """ 模型训练核心函数说明: epoch:完整遍历一次所有训练数据的次数 step:这里等价于epoch(迭代次数) model:模型实例 loss_func:损失函数 opt:优化器 train_dl:训练数据加载器 valid_dl:验证数据加载器 训练模式(model.train()):启用Batch Normalization和Dropout 验证模式(model.eval()):禁用Batch Normalization和Dropout(使用移动平均/不丢弃神经元) 训练时更新权重,验证时仅计算损失不更新 """ # 模型训练函数:执行多轮训练和验证 # 功能:完成模型的训练过程,每轮训练后在验证集评估损失 # 参数: # steps:训练轮数(等价于epoch数) # model:待训练的神经网络模型 # loss_func:损失函数(如交叉熵损失) # opt:优化器(如Adam) # train_dl:训练集数据加载器 # valid_dl:验证集数据加载器 # 返回:无(打印每轮验证集损失) def fit(steps, model_2, loss_func_1, opt_1, train_dl_1, valid_dl_1): # 遍历每一轮训练 for step in range(steps): model_2.train() # 将模型设为训练模式:启用Dropout/BatchNorm # 遍历训练集的每个批次 for xb_1,yb_1 in train_dl_1: # 计算当前批次损失并更新模型参数 loss_batch(model_2, loss_func_1, xb_1, yb_1,opt_1) model_2.eval() # 将模型设为验证模式:禁用Dropout/BatchNorm with torch.no_grad(): # 禁用梯度计算,节省内存和计算资源 # 计算验证集每个批次的损失和样本数 # losses,nums = zip(*[loss_batch(model_2, loss_func_1, xb_1, yb_1) for xb_1,yb_1 in valid_dl_1]) # 初始化空列表用于存储每个批次的损失和数量 batch_results = [] # 遍历验证数据加载器中的每个批次 for xb_1, yb_1 in valid_dl_1: # 计算当前批次的损失和样本数量 batch_loss, batch_num = loss_batch(model_2, loss_func_1, xb_1, yb_1) # 将结果添加到列表中 batch_results.append((batch_loss, batch_num)) # 解压列表,分别得到所有批次的损失和数量 losses, nums = zip(*batch_results) # 计算验证集平均损失:总损失/总样本数 """ np.sum(np.multiply(losses,nums))是 NumPy 逐元素乘法 作用是将 losses 和 nums 两个数组中对应位置的元素逐一相乘,最终返回一个与输入同形状的新数组。 """ val_loss = np.sum(np.multiply(losses,nums))/np.sum(nums) # 打印当前训练轮数和验证集损失 print("当前step:"+ str(step),'验证集损失:'+ str(val_loss)) # 批次损失计算函数:计算单个批次损失,可选更新模型参数 # 功能:计算批次损失,若传入优化器则执行反向传播和参数更新 # 参数: # model_3:神经网络模型 # loss_func_2:损失函数 # xb_2:单个批次的特征张量,形状[batch_size, 784] # yb_2:单个批次的标签张量,形状[batch_size] # opt_2:优化器(可选,None时仅计算损失不更新参数) # 返回: # 元组:(当前批次的损失值(标量), 当前批次的样本数量) def loss_batch(model_3, loss_func_2, xb_2, yb_2, opt_2=None): # 计算损失: # model(xb):模型输出,形状[batch_size, 10] # yb.long():将标签转为长整型(交叉熵损失要求标签为int64) loss = loss_func_2(model_3(xb_2),yb_2.long()) # 如果传入了优化器,则执行反向传播和参数更新 if opt_2 is not None: loss.backward() # 反向传播:计算参数梯度 opt_2.step() # 优化器更新参数 opt_2.zero_grad() # 清空梯度缓存(防止梯度累积) # 返回损失值(转换为Python标量)和批次样本数 return loss.item(),len(xb_2) # 创建训练/验证数据加载器 train_dl,valid_dl=get_data(train_ds,valid_ds,bs) # 创建模型和优化器 model,opt=get_model() # 开始训练:100轮 fit(50,model,loss_func,opt,train_dl,valid_dl) # 计算验证集上的模型准确率 correct = 0 # 正确预测的样本数 total = 0 # 总样本数 # 遍历验证集每个批次 for xb,yb in valid_dl: output = model(xb) # 模型前向传播,得到输出logits # torch.max(output.data,1):按维度1(类别维度)取最大值 # 返回值:(最大值, 最大值索引),索引即为预测的数字类别 _,predicted = torch.max(output.data,1) total += yb.size(0) # 累加当前批次样本数 # 统计预测正确的样本数:(predicted == yb)生成布尔张量,sum()求和,item()转标量 correct += (predicted == yb).sum().item() # 打印模型在验证集上的准确率 print("Accuracy of the network on the 10000 test images: %d %%" % (100 * correct / total))MNIST-手写数字识别分类案例
张小明
前端开发工程师
Langchain-Chatchat问答系统灰度期间服务优雅启停
Langchain-Chatchat问答系统灰度期间服务优雅启停 在企业级AI应用逐步从实验走向生产落地的今天,一个看似不起眼但至关重要的工程细节正悄然决定着系统的可靠性——如何在不中断用户体验的前提下完成服务升级?尤其是在部署像 Langchain-Chatchat 这类基于…
Langchain-Chatchat结合Argo CD实现GitOps部署
Langchain-Chatchat 结合 Argo CD 实现 GitOps 部署 在企业智能化转型的浪潮中,如何安全、可靠、可追溯地部署基于大语言模型(LLM)的知识管理系统,正成为 DevOps 与 AI 工程化交叉领域的重要课题。传统方式下,本地知识…
Austroads:车速管理研究综述:实证依据与指导建议(英) 2025
该报告是 Austroads 为更新《道路安全指南第 3 部分:安全车速》而开展的研究综述,核心是整合车速管理的最新实证与实践经验,为澳大拉西亚地区道路安全政策提供支撑。一、研究背景与目标背景:现有指南需纳入国际前沿方法࿰…
毫米波雷达:从3D到4D,智能汽车的“全天候眼”是怎么炼成的
本文约7,085字,建议收藏阅读作 者 | aFakeProgramer出 品 | 汽车电子与软件摘要各位技术佬、汽车控们,今天咱们聚焦智能汽车里最“耐造”的传感器——毫米波雷达。它不像激光雷达娇贵,也不似摄像头“看天吃饭”,却是L2到L4级自动驾…
前后端分离线上历史馆藏系统系统|SpringBoot+Vue+MyBatis+MySQL完整源码+部署教程
摘要 随着数字化时代的快速发展,博物馆和文化机构对线上馆藏管理系统的需求日益增长。传统的馆藏管理系统通常采用单体架构,存在前后端耦合度高、维护困难、扩展性差等问题。线上历史馆藏系统旨在解决这些问题,通过前后端分离架构实现高效、…
在线监测:让燃气轮机在能源转型中更可靠、更高效
最近几年,每当遇到极端高温或寒潮天气,各地的燃气电厂常常进入“战时状态”。在风电、光伏出力不足的时刻,这些电厂必须迅速顶上,保障电网稳定。这种频繁启停、快速爬坡的运行方式,对电厂最核心、最昂贵的设备——燃气…