news 2026/4/17 22:29:42

TensorFlow学习系列01 | 实现mnist手写数字识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
TensorFlow学习系列01 | 实现mnist手写数字识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

一、前置知识

1、知识总结

概念

作用

归一化

统一数据范围,加速训练

卷积层

提取图像局部特征

池化层

压缩数据,增强鲁棒性

全连接层

综合特征,输出分类

激活函数(ReLU)

引入非线性

2、CNN网络

二、代码实现

1、准备工作

1.1.设置GPU

import tensorflow as tf gpus = tf.config.list_physical_devices("GPU") if gpus: gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用 tf.config.set_visible_devices([gpu0],"GPU") print(gpus)
2026-01-08 18:44:48.698143: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

1.2.导入数据

import tensorflow as tf from tensorflow.keras import datasets, layers, models import matplotlib.pyplot as plt # 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签 (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 3s 0us/step

1.3.归一化

数据归一化作用:

  • 使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确
  • 加快学习算法的收敛速度
# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。) train_images, test_images = train_images / 255.0, test_images / 255.0 # 查看数据维数信息 train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

1.4.可视化图片

# 将数据集前20个图片数据可视化显示 # 进行图像大小为20宽、10长的绘图(单位为英寸inch) plt.figure(figsize=(20,10)) # 遍历MNIST数据集下标数值0~49 for i in range(20): # 将整个figure分成2行10列,绘制第i+1个子图。 plt.subplot(2,10,i+1) # 设置不显示x轴刻度 plt.xticks([]) # 设置不显示y轴刻度 plt.yticks([]) # 设置不显示子图网格线 plt.grid(False) # 图像展示,cmap为颜色图谱,"plt.cm.binary"为matplotlib.cm中的色表 plt.imshow(train_images[i], cmap=plt.cm.binary) # 设置x轴标签显示为图片对应的数字 plt.xlabel(train_labels[i]) # 显示图片 plt.show()

1.5.调整图片格式

#调整数据到我们需要的格式 train_images = train_images.reshape((60000, 28, 28, 1)) test_images = test_images.reshape((10000, 28, 28, 1)) train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

2、训练模型

2.1.构建CNN网络

# 创建并设置卷积神经网络 # 卷积层:通过卷积操作对输入图像进行降维和特征抽取 # 池化层:是一种非线性形式的下采样。主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的鲁棒性。 # 全连接层:在经过几个卷积和池化层之后,神经网络中的高级推理通过全连接层来完成。 model = models.Sequential([ # 设置二维卷积层1,设置32个3*3卷积核,activation参数将激活函数设置为ReLu函数,input_shape参数将图层的输入形状设置为(28, 28, 1) # ReLu函数作为激活励函数可以增强判定函数和整个神经网络的非线性特性,而本身并不会改变卷积层 # 相比其它函数来说,ReLU函数更受青睐,这是因为它可以将神经网络的训练速度提升数倍,而并不会对模型的泛化准确度造成显著影响。 layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), #池化层1,2*2采样 layers.MaxPooling2D((2, 2)), # 设置二维卷积层2,设置64个3*3卷积核,activation参数将激活函数设置为ReLu函数 layers.Conv2D(64, (3, 3), activation='relu'), #池化层2,2*2采样 layers.MaxPooling2D((2, 2)), layers.Flatten(), #Flatten层,连接卷积层与全连接层 layers.Dense(64, activation='relu'), #全连接层,特征进一步提取,64为输出空间的维数,activation参数将激活函数设置为ReLu函数 layers.Dense(10) #输出层,输出预期结果,10为输出空间的维数 ]) # 打印网络结构 model.summary()
2026-01-08 18:49:53.891561: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2026-01-08 18:49:55.030714: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10099 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3080 Ti, pci bus id: 0000:5e:00.0, compute capability: 8.6 Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 26, 26, 32) 320 max_pooling2d (MaxPooling2D (None, 13, 13, 32) 0 ) conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 max_pooling2d_1 (MaxPooling (None, 5, 5, 64) 0 2D) flatten (Flatten) (None, 1600) 0 dense (Dense) (None, 64) 102464 dense_1 (Dense) (None, 10) 650 ================================================================= Total params: 121,930 Trainable params: 121,930 Non-trainable params: 0 _________________________________________________________________

2.2.编译模型

""" 这里设置优化器、损失函数以及metrics model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准 """ model.compile( # 设置优化器为Adam优化器 optimizer='adam', # 设置损失函数为交叉熵损失函数(tf.keras.losses.SparseCategoricalCrossentropy()) # from_logits为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 设置性能指标列表,将在模型训练时监控列表中的指标 metrics=['accuracy'])

2.3.训练模型

""" 这里设置输入训练数据集(图片及标签)、验证数据集(图片及标签)以及迭代次数epochs 关于model.fit()函数的具体介绍可参考我的博客: https://blog.csdn.net/qq_38251616/article/details/122321757 """ history = model.fit( # 输入训练集图片 train_images, # 输入训练集标签 train_labels, # 设置10个epoch,每一个epoch都将会把所有的数据输入模型完成一次训练。 epochs=10, # 设置验证集 validation_data=(test_images, test_labels))
Epoch 1/10 2026-01-08 18:51:21.815488: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101 2026-01-08 18:51:24.686085: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. 1875/1875 [==============================] - 14s 4ms/step - loss: 0.6748 - accuracy: 0.7807 - val_loss: 0.3137 - val_accuracy: 0.9037 Epoch 2/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.2385 - accuracy: 0.9280 - val_loss: 0.1592 - val_accuracy: 0.9516 Epoch 3/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.1465 - accuracy: 0.9559 - val_loss: 0.1082 - val_accuracy: 0.9665 Epoch 4/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.1113 - accuracy: 0.9657 - val_loss: 0.0896 - val_accuracy: 0.9729 Epoch 5/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0913 - accuracy: 0.9717 - val_loss: 0.0877 - val_accuracy: 0.9732 Epoch 6/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0796 - accuracy: 0.9758 - val_loss: 0.0700 - val_accuracy: 0.9780 Epoch 7/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0712 - accuracy: 0.9777 - val_loss: 0.0613 - val_accuracy: 0.9816 Epoch 8/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0638 - accuracy: 0.9802 - val_loss: 0.0712 - val_accuracy: 0.9759 Epoch 9/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0584 - accuracy: 0.9814 - val_loss: 0.0574 - val_accuracy: 0.9815 Epoch 10/10 1875/1875 [==============================] - 8s 4ms/step - loss: 0.0550 - accuracy: 0.9830 - val_loss: 0.0509 - val_accuracy: 0.9844

3、模型预测

plt.imshow(test_images[1])
<matplotlib.image.AxesImage at 0x7f917a9d5100>

pre = model.predict(test_images) # 对所有测试图片进行预测 pre[1] # 输出第一张图片的预测结果
313/313 [==============================] - 1s 2ms/step array([ 4.0916557, 8.121223 , 17.289515 , -1.078687 , -17.282494 , -4.956851 , -0.79667 , -10.196654 , 4.119256 , -14.890562 ], dtype=float32)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 6:46:51

基于微信小程序的点餐小程序开发与设计

摘要 近年来&#xff0c;伴随者互联网产业的快速发展&#xff0c;各种信息化软件应运而生。当下&#xff0c;人们出门在外一部手机就可以解决线下所有的交易支付&#xff0c;人们对于信息化软件的使用也已不陌生。经济的发展&#xff0c;人均收入的提高&#xff0c;人们去餐饮店…

作者头像 李华
网站建设 2026/4/18 6:31:24

Spring4Shell CVE-2022-22965原理及复现

Spring4Shell&#xff08;正式编号为 CVE-2022-22965&#xff09;是 2022 年 3 月底发现的一个存在于 Spring Framework 中的远程代码执行&#xff08;RCE&#xff09;高危漏洞。由于 Spring 框架在 Java 生态中的核心地位&#xff0c;该漏洞曾引发了全行业的广泛关注&#xff…

作者头像 李华
网站建设 2026/4/18 1:01:06

拥抱大数据领域数据可视化,提升数据分析效率

拥抱大数据领域数据可视化&#xff0c;提升数据分析效率关键词&#xff1a;大数据、数据可视化、数据分析效率、可视化工具、可视化方法摘要&#xff1a;本文深入探讨了大数据领域的数据可视化&#xff0c;旨在帮助大家通过数据可视化来提升数据分析效率。首先介绍了数据可视化…

作者头像 李华
网站建设 2026/4/18 6:31:09

数字孪生在航空发动机总体性能中的应用前景

截至2026年初&#xff0c;数字孪生技术在航空发动机总体性能优化中的应用已从概念验证迈向规模化落地阶段&#xff0c;展现出广阔的应用前景。其核心价值在于通过构建高保真、多物理场耦合、全生命周期覆盖的虚拟镜像&#xff0c;实现对发动机设计、制造、试验、运维等各环节性…

作者头像 李华
网站建设 2026/4/18 6:30:31

雷军又发奖了!1000万奖金花落“玄戒”,未来5年还要砸2000亿搞研发

1月8日一早&#xff0c;科技圈就被雷军的一条消息刷屏了。小米不仅开了个隆重的技术大奖颁奖礼&#xff0c;雷军还在社交平台上大大方方地宣布&#xff1a;今年的千万技术大奖&#xff0c;被“玄戒O1”团队稳稳拿下了。能在小米这么多顶尖项目里脱颖而出&#xff0c;拿到这沉甸…

作者头像 李华
网站建设 2026/4/18 6:31:53

基于51单片机的排队叫号系统—两块单片机串行通信

基于51单片机的排队叫号系统 &#xff08;仿真&#xff0b;程序原理图&#xff0b;设计报告&#xff09; 功能介绍 具体功能&#xff1a; 1.主机通过4个按键模拟4个柜台号&#xff0c;按下按键实现叫号&#xff1b; 2.柜台叫号后&#xff0c;LCD1602显示被叫的号码及叫号的柜…

作者头像 李华