news 2026/4/18 6:27:49

在 TensorFlow 中实现卷积神经网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
在 TensorFlow 中实现卷积神经网络

原文:towardsdatascience.com/implementing-convolutional-neural-networks-in-tensorflow-bc1c4f00bd34

欢迎来到我们**深度学习图解**系列的实用实施指南。在这个系列中,我们弥合了理论与实践之间的差距,将之前文章中探讨的神经网络概念生动地呈现出来。

深度学习,图解

在今天的文章中,我们将使用 TensorFlow 构建一个卷积神经网络(CNN)。请确保阅读之前的CNN 文章,因为这篇文章假设你已经熟悉 CNN 的内部工作原理和数学基础。我们将只关注实现,所以先前的知识将帮助你更容易地跟随。

深度学习图解,第三部分:卷积神经网络

我们将创建一个简单的图像分类器,该分类器可以预测给定的图像是否是‘X’。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/c18369f8a3fd9625d04a716283c30f07.png

我们将详细分解每一步,以确保你理解“如何”以及“为什么”。

第一步:导入必要的库

importtensorflowastffromtensorflow.keras.modelsimportSequentialfromtensorflow.keras.layersimportConv2D,MaxPooling2D,Flatten,Densefromtensorflow.keras.optimizersimportAdamimportnumpyasnpimportmatplotlib.pyplotasplt

TensorFlow 和 Keras(它是 TensorFlow 中的一个高级 API)将处理 CNN 的创建和训练,而 NumPy 和 Matplotlib 将帮助我们进行数据处理和可视化。

注意:为了确保每次运行代码时结果一致,我们将设置一个随机种子:

# Setting seed for reproducibilitynp.random.seed(42)tf.random.set_seed(42)

设置种子确保每次运行代码时随机过程都以相同的方式进行,就像每次玩牌时都按照完全相同的顺序洗牌一样。

第二步:理解和生成数据

让我们首先生成模型将学习分类的图像。之前我们看到一个‘X’可以表示为一个 5×5 像素的图像,如下所示:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/20fe9154a99d2dd12bf58717de57e868.png

让我们将这些翻译成代码:

# 'X' patterndefgenerate_x_image():returnnp.array([[1,0,0,0,1],[0,1,0,1,0],[0,0,1,0,0],[0,1,0,1,0],[1,0,0,0,1]])

这个函数生成一个简单的 5×5 的‘X’图像。接下来,我们将创建一个函数来生成随机的 5×5 图像,这些图像不类似于‘X’:

defgenerate_not_x_image():# Ensuring not to generate an 'X' patternwhileTrue:img=np.random.randint(2,size=(5,5))ifnotnp.array_equal(img,generate_x_image()):returnimg

第三步:构建数据集

函数准备就绪后,我们现在可以创建一个包含 1,000 张图片的数据集。我们将相应地标记它们,其中 1 代表‘X’的图片,0 代表不是‘X’的图片:

# Create a datasetnum_samples=1000images=[]labels=[]for_inrange(num_samples):ifnp.random.rand()>0.5:images.append(generate_x_image())labels.append(1)else:images.append(generate_not_x_image())labels.append(0)images=np.array(images).reshape(-1,5,5,1)labels=np.array(labels)

这段代码生成了 1,000 张图片,其中一半包含一个‘X’,另一半则没有。然后我们将图片重塑,以确保它们具有 CNN 所需的正确维度。

为了有效地训练我们的模型,我们将此数据集分为训练集和测试集:

# Split the data into training and testing setsfromsklearn.model_selectionimporttrain_test_split x_train,x_test,y_train,y_test=train_test_split(images,labels,test_size=0.2,random_state=42)

这种划分保留了 80%的数据用于训练模型,20%用于测试。测试集帮助我们评估模型在新、未见过的数据上的表现。

在我们深入模型构建之前,让我们看看数据集中的一些图像,以了解我们正在处理的内容。

# Function to display imagesdefdisplay_sample_data(images,labels,num_samples=5):plt.figure(figsize=(10,2))foriinrange(num_samples):ax=plt.subplot(1,num_samples,i+1)plt.imshow(images[i].reshape(5,5),cmap='gray_r')plt.title(f'Label:{labels[i]}')plt.axis('off')plt.show()

此函数显示我们训练集中的图像,帮助我们确认数据是否正确标记和格式化。

# Display first 5 samples of our training datadisplay_sample_data(x_train,y_train)

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ad861997b1466a28caaba49c059ba8c9.png

第 4 步:构建 CNN 模型

现在我们数据已准备就绪,让我们构建 CNN!这是我们之前使用的架构:

1 – 卷积层: 对输入图像应用四个 3×3 的过滤器以检测特征,并创建四个特征图

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/adfb18670ea552f6289b0eca99553b2e.png

2 – 最大池化层: 减少特征图的维度,使模型更高效

3 – 展平层: 将 2D 数据转换为 1D 数组,为神经网络做准备

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2001e0556a5c85367b58834f179e4f64.png

4 – 隐藏层: 一个完全连接的隐藏层,包含三个具有 ReLU 激活函数的神经元

5 – 输出层: 带有 sigmoid 激活函数的单个神经元

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/cea7b044d7b7ab2f24f52679fdb68ca5.png

model=Sequential([# 1 - Convolutional LayerConv2D(4,(3,3),activation='relu',input_shape=(5,5,1)),# 2 - Max-Pooling LayerMaxPooling2D(pool_size=(2,2)),# 3 - Flatten LayerFlatten(),# 4 - Hidden LayerDense(3,activation='relu'),# 5 - Output LayerDense(1,activation='sigmoid')])

第 5 步:编译模型

编译模型至关重要,因为它定义了模型将如何学习。以下是编译函数每个部分的作用:

  1. 优化器(Adam): 优化器调整模型的权重以最小化损失函数。我们在这里使用Adam,但这是一个可以使用其他优化器的列表。

  2. 损失函数(二元交叉熵):衡量模型的预测值与实际结果之间的差距。由于我们处理的是二元分类(X 或非 X),**二元交叉熵**是合适的。损失值是模型在训练过程中试图最小化的值。较低的损失值表示更好的性能(即,模型的预测值更接近实际值)。损失值直接影响模型的训练过程,优化器使用它来更新模型权重,以最小化损失。

  3. 指标(准确率):指标用于评估模型的性能。我们使用准确率来跟踪所有预测中正确预测的比例。指标是评估模型性能的额外措施,但不会被优化器用于在训练过程中调整模型。它们提供了一种评估模型性能好坏的方法。

注意:虽然准确率是一个常见的指标,但它并不总是最可靠的,特别是在某些场景下,如不平衡的数据集。然而,为了简单起见,我们在这里使用它。如果您对探索可能提供更细致模型性能视图的其他评估指标感兴趣,这篇文章介绍了一些替代方案。

model.compile(optimizer=Adam(),loss='binary_crossentropy',metrics=['accuracy'])

第 6 步:训练模型

训练模型涉及多次(epochs)向模型提供训练数据,以便它能够学会做出更好的预测。一个 epoch 是整个训练数据集的一次完整遍历。我们将对模型进行 10 个 epoch 的训练:

history=model.fit(x_train,y_train,epochs=10,validation_data=(x_test,y_test))

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/444f1536a01433271dee955cba1e1515.png

第 7 步:评估模型

训练完成后,我们在测试数据上评估模型的性能:

loss,accuracy=model.evaluate(x_test,y_test)print(f'Test Accuracy:{accuracy*100:.2f}%')

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/4c7e283b4653bdf07281badf2099c15f.png

准确率指标告诉我们,测试数据中有 94.8%的图像被正确分类。

第 8 步:可视化训练过程

最后,让我们可视化模型在各个 epoch 中的准确率变化。这有助于我们了解模型在训练过程中的学习效果:

plt.plot(history.history['accuracy'],label='accuracy')plt.plot(history.history['val_accuracy'],label='val_accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.show()

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/85e05fae4dfbefd8463a2b31ec97c190.png

就这样。我们用 TensorFlow 在不到 5 分钟内构建了一个简单的卷积神经网络,用于预测图像是否为‘X’!

实践中的深度学习

如往常一样,欢迎您在LinkedIn上与我联系,提出任何评论或问题!

注意:除非特别说明,所有图片均为作者所有。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/28 17:00:38

STM32CubeMX时钟配置:超详细版低功耗设计指南

STM32低功耗设计的“心脏”:如何用好STM32CubeMX配置时钟树?你有没有遇到过这样的问题?一个本该靠纽扣电池运行一年的传感器节点,结果三个月就没电了。排查半天,发现MCU一直在“偷偷”耗电——而罪魁祸首,可…

作者头像 李华
网站建设 2026/4/11 17:19:08

企业级小型医院医疗设备管理系统管理系统源码|SpringBoot+Vue+MyBatis架构+MySQL数据库【完整版】

摘要 随着医疗行业的快速发展,小型医院在医疗设备管理方面面临诸多挑战,包括设备信息记录不准确、维护周期混乱、使用效率低下等问题。传统的人工管理方式难以满足现代化医疗设备管理的需求,亟需一套高效、智能化的管理系统来提升设备管理的规…

作者头像 李华
网站建设 2026/4/17 21:34:00

【南京航空航天大学主办,往届已见刊检索 | AP (ISSN: 2352-538X)出版 | 大咖嘉宾与会交流 | 录用率高,见刊快】第五届工程管理与信息科学国际学术会议 (EMIS 2026)

第五届工程管理与信息科学国际学术会议 (EMIS 2026) 2026 5th International Conference on Engineering Management and Information Science 大会时间:2026年1月23-25日 大会地点:中国-沈阳 大会官网:www.icemis.net【投稿参会】 报名…

作者头像 李华
网站建设 2026/4/15 22:33:15

LUCEDA IPKISS Tutorial 84:填充周期阵列

案例分享:周期阵列填充函数:ArrayFillTraceWindow所有代码如下: from si_fab import all as pdk import ipkiss3.all as i3total_length 250 reference i3.LayoutCell().Layout(elements[i3.Rectangle(layeri3.TECH.PPLAYER.V12, box_size(…

作者头像 李华
网站建设 2026/4/17 17:23:42

企业级线上学习资源智能推荐系统管理系统源码|SpringBoot+Vue+MyBatis架构+MySQL数据库【完整版】

摘要 随着数字化教育的快速发展,企业对于高效、个性化的员工培训需求日益增长。传统的线下培训模式受限于时间和空间,难以满足现代企业灵活化、智能化的学习需求。企业级线上学习资源智能推荐系统通过整合海量学习资源,结合用户画像和行为分析…

作者头像 李华