news 2026/4/18 8:23:39

DAY 36 简单的神经网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY 36 简单的神经网络

一、PyTorch和cuda的安装

二、查看显卡信息的命令行命令(cmd中)

三、cuda的检查

import torch # 检查CUDA是否可用 if torch.cuda.is_available(): print("CUDA可用!") # 获取可用的CUDA设备数量 device_count = torch.cuda.device_count() print(f"可用的CUDA设备数量: {device_count}") # 获取当前使用的CUDA设备索引 current_device = torch.cuda.current_device() print(f"当前使用的CUDA设备索引: {current_device}") # 获取当前CUDA设备的名称 device_name = torch.cuda.get_device_name(current_device) print(f"当前CUDA设备的名称: {device_name}") # 获取CUDA版本 cuda_version = torch.version.cuda print(f"CUDA版本: {cuda_version}") else: print("CUDA不可用。")

四、简单神经网络的流程

1.数据预处理(归一化、转化为张量)

注意事项:

  • 分类任务中,若标签为整数(如0/1/2类别),需转换为long型(对应PyTorch中的torch.long),否则交叉熵损失函数会报错。
  • 回归任务中标签需转换为float类型(如torch.float32)
# 用4特征,3分类的鸢尾花数据集作为我们今天的数据集 from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split import numpy as np # 加载鸢尾花数据集 iris = load_iris() X = iris.data # 特征数据 y = iris.target # 标签数据 # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 打印下尺寸 print(X_train.shape) print(y_train.shape) print(X_test.shape) print(y_test.shape)
# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式 from sklearn.preprocessing import MinMaxScaler scaler = MinMaxScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) #确保训练集和测试集是相同的缩放
# 将数据转换为 PyTorch 张量,因为 PyTorch 使用张量进行训练 # y_train和y_test是整数,所以需要转化为long类型,如果是float32,会输出1.0 0.0 X_train = torch.FloatTensor(X_train) y_train = torch.LongTensor(y_train) X_test = torch.FloatTensor(X_test) y_test = torch.LongTensor(y_test)

2.模型的定义

2.1继承nn.Module类

2.2定义每一个层

2.3定义前向传播流程

import torch import torch.nn as nn import torch.optim as optim class MLP(nn.Module): # 定义一个多层感知机(MLP)模型,继承父类nn.Module def __init__(self): # 初始化函数 super(MLP, self).__init__() # 调用父类的初始化函数 # 前三行通用,后面的是自定义的 self.fc1 = nn.Linear(4, 10) # 输入层到隐藏层 self.relu = nn.ReLU() self.fc2 = nn.Linear(10, 3) # 隐藏层到输出层 # 输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy,交叉熵函数内部有softmax函数,会把输出转化为概率 def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out # 实例化模型 model = MLP() # def forward(self,x): #前向传播 # x=torch.relu(self.fc1(x)) #激活函数 # x=self.fc2(x) #输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy # return x

3.定义损失函数和优化器

# 分类问题使用交叉熵损失函数 criterion = nn.CrossEntropyLoss() # 使用随机梯度下降优化器 optimizer = optim.SGD(model.parameters(), lr=0.01) # # 使用自适应学习率的化器 # optimizer = optim.Adam(model.parameters(), lr=0.001)

4.定义训练流程

# 训练模型 num_epochs = 20000 # 训练的轮数 # 用于存储每个 epoch 的损失值 losses = [] for epoch in range(num_epochs): # range是从0开始,所以epoch是从0开始 # 前向传播 outputs = model.forward(X_train) # 显式调用forward函数 # outputs = model(X_train) # 常见写法隐式调用forward函数,其实是用了model类的__call__方法 loss = criterion(outputs, y_train) # output是模型预测值,y_train是真实标签 # 反向传播和优化 optimizer.zero_grad() #梯度清零,因为PyTorch会累积梯度,所以每次迭代需要清零,梯度累计是那种小的bitchsize模拟大的bitchsize loss.backward() # 反向传播计算梯度 optimizer.step() # 更新参数 # 记录损失值 losses.append(loss.item()) # 打印训练信息 if (epoch + 1) % 100 == 0: # range是从0开始,所以epoch+1是从当前epoch开始,每100个epoch打印一次 print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

5.可视化loss流程

import matplotlib.pyplot as plt # 可视化损失曲线 plt.plot(range(num_epochs), losses) plt.xlabel('Epoch') plt.ylabel('Loss') plt.title('Training Loss over Epochs') plt.show()

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 7:26:30

BiliFM:B站音频下载的终极解决方案

BiliFM:B站音频下载的终极解决方案 【免费下载链接】BiliFM 下载指定 B 站 UP 主全部或指定范围的音频,支持多种合集。A script to download all audios of the Bilibili uploader you love. 项目地址: https://gitcode.com/jingfelix/BiliFM 想要…

作者头像 李华
网站建设 2026/4/18 7:03:15

Linux也可以有图形化界面,手把手教你链接X11远程桌面

图形化介绍 Linux图形化界面(GUI)的发展可以说是一部开源软件与操作系统用户体验不断演进的历史。 图形界面需求的出现:早期的Unix和类Unix系统(包括最早的 Linux)全部是基于命令行的,用户只能通过终端输入…

作者头像 李华
网站建设 2026/4/18 7:23:05

汽车零部件企业如何通过OEE钻取分析实现降本增效?

在制造业的激烈竞争中,设备综合效率(OEE)已经成为衡量汽车零部件企业生产效能的核心指标。作为汽车产业链的基石,零部件生产效率的提升不仅关乎企业的成本控制,更直接影响整车厂的供应链稳定性。然而,许多汽…

作者头像 李华
网站建设 2026/4/18 7:26:44

收藏这篇就够了!AI大模型61大应用场景全解析,从入门到实践

本文详细介绍了人工智能大模型在12个领域的61个应用场景,涵盖城市治理、医疗、金融、教育等。大模型通过自然语言处理、图像识别等技术,实现智能诊断、风险评估、个性化学习等功能,推动各行业数字化转型,为生活和工作带来便利&…

作者头像 李华
网站建设 2026/4/18 7:23:15

2025年高频Java面试集锦,害怕面试?收藏这篇就够了

一、Java 基础与集合 1. ArrayList 和 LinkedList 的区别? ArrayList 底层是数组,查询快,增删慢;适合查多改少。 LinkedList 底层是双向链表,增删快,查询慢;适合频繁插入删除。 2. HashMap 的底…

作者头像 李华
网站建设 2026/4/10 22:18:49

第二章 CentOS配置YUM源

一、YUM源配置1,挂载光盘2、清空默认源配置3、创建本地yum源文件4、刷新yum源5、安装软件二,源码安装软件1、安装依赖包2、解压3、配置4、编译5、安装

作者头像 李华