news 2026/4/18 3:24:54

【联邦学习入门指南】Part 4:从零实现一个 FL 系统

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【联邦学习入门指南】Part 4:从零实现一个 FL 系统

🛠️动手实战:环境配置 + 代码实现 + 避坑指南
🎯目标:抛开晦涩的公式,手把手教你在自己的电脑上搭建并运行第一个联邦学习模拟系统
💡核心:从安装软件到编写“数据切分、客户端训练、服务器聚合”的全流程


📋 目录

  • 1. 准备工作:硬件与软件环境配置
  • 2. 实战思维:如何在单机上模拟联邦?
  • 3. 步骤一:数据准备与切分 (Data Splitting)
  • 4. 步骤二:定义共享模型与客户端逻辑
  • 5. 步骤三:定义服务器逻辑 (FedAvg)
  • 6. 步骤四:组合运行与结果展示

1. 准备工作:硬件与软件环境配置

在写代码之前,我们先确保你的电脑(哪怕是普通笔记本)已经准备好了“厨房”。

1.1 硬件要求

对于本教程的 MNIST(手写数字)案例,无需昂贵的显卡

硬件最低要求推荐配置说明
CPU任意双核i5 / R5 及以上代码量小,CPU 运行更方便,无需配置显卡驱动。
内存8 GB16 GB联邦学习需在内存中暂存多个模型副本,内存过小易报错。

1.2 软件安装 (Anaconda + PyTorch)

为了避免复杂的环境配置,我们采用最稳妥的方案:

  1. 安装 Anaconda(你的工具箱):
  • 访问 Anaconda 官网 下载安装。
  • 安装后打开Anaconda Prompt(黑色终端窗口)。
  1. 创建虚拟环境与安装 PyTorch(你的发动机):
    在终端中依次输入以下指令(每行输入后按回车):
# 1. 创建一个名为 fl_demo 的环境conda create -n fl_demopython=3.9-y# 2. 激活环境 (最重要的一步,看到左侧括号变了才算成功)conda activate fl_demo# 3. 安装 PyTorch (CPU版) 和 Jupyter (编辑器)pipinstalltorch torchvision numpy jupyter
  1. 启动编辑器
    在终端输入jupyter notebook,浏览器会自动弹出一个网页,点击右上角New->Python 3,即可开始写代码。

2. 实战思维:如何在单机上模拟联邦?

在真实世界中,联邦学习涉及 10 台手机和 1 台服务器通过网络传输。但在学习阶段,我们在一台电脑上通过for 循环来模拟这个过程。

  • 数据隔离模拟:我们将一份大数据集强行切分成 10 份,分配给 10 个变量,假装它们互不通气。
  • 通信模拟:变量之间的赋值(server_model = client_model)代替了网络传输。

3. 步骤一:数据准备与切分 (Data Splitting)

我们需要编写代码,将 MNIST 数据集切分给 N 个客户端。

操作:将以下代码复制到 Jupyter 的第一个单元格并运行。

importtorchfromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataLoader,Subsetimportcopydefget_dataset(num_clients=5):""" 下载并切分 MNIST 数据集 """# 1. 下载数据# 注意:如果网络报错,请将 download=True 改为 False,并手动下载数据集到 ./data 目录train_data=datasets.MNIST(root='./data',train=True,download=True,transform=transforms.ToTensor())# 2. 模拟数据切分 (IID切分:每个用户拿到的数据分布相似)data_len=len(train_data)indices=list(range(data_len))split_size=data_len//num_clients client_loaders=[]foriinrange(num_clients):# 截取属于第 i 个用户的数据索引subset_indices=indices[i*split_size:(i+1)*split_size]subset=Subset(train_data,subset_indices)# 封装成 DataLoaderloader=DataLoader(subset,batch_size=32,shuffle=True)client_loaders.append(loader)returnclient_loaders# 测试一下client_loaders=get_dataset(num_clients=5)print(f"数据准备完毕!成功创建{len(client_loaders)}个客户端数据源。")

4. 步骤二:定义共享模型与客户端逻辑

所有参与方必须使用相同的神经网络结构。客户端负责“接收模型 -> 训练 -> 返回参数”。

操作:复制到第二个单元格并运行。

importtorch.nnasnnimporttorch.nn.functionalasF# --- 1. 定义网络结构 ---classSimpleCNN(nn.Module):def__init__(self):super(SimpleCNN,self).__init__()self.conv1=nn.Conv2d(1,10,kernel_size=5)self.conv2=nn.Conv2d(10,20,kernel_size=5)self.fc=nn.Linear(320,10)defforward(self,x):x=F.relu(F.max_pool2d(self.conv1(x),2))x=F.relu(F.max_pool2d(self.conv2(x),2))x=x.view(-1,320)x=self.fc(x)returnx# --- 2. 定义客户端 (Client) ---classClient:def__init__(self,client_id,data_loader,device='cpu'):self.client_id=client_id self.data_loader=data_loader self.device=device self.model=SimpleCNN().to(self.device)deflocal_train(self,global_weights,epochs=1):# 加载服务器发来的参数self.model.load_state_dict(global_weights)# 本地训练 (常规的 PyTorch 训练流程)optimizer=torch.optim.SGD(self.model.parameters(),lr=0.01,momentum=0.5)self.model.train()forepochinrange(epochs):fordata,targetinself.data_loader:data,target=data.to(self.device),target.to(self.device)optimizer.zero_grad()output=self.model(data)loss=F.cross_entropy(output,target)loss.backward()optimizer.step()# 关键:只返回参数 (state_dict),不返回数据!returncopy.deepcopy(self.model.state_dict())

5. 步骤三:定义服务器逻辑 (FedAvg)

服务器通过FedAvg (联邦平均算法)将收集到的参数进行加权平均。

操作:复制到第三个单元格并运行。

classServer:def__init__(self,device='cpu'):self.global_model=SimpleCNN().to(device)self.device=devicedefaggregate(self,client_weights_list):""" FedAvg 核心:对参数取平均 """# 拿出第一个客户端的参数作为基准avg_weights=copy.deepcopy(client_weights_list[0])# 逐层累加其他客户端的参数forkeyinavg_weights.keys():foriinrange(1,len(client_weights_list)):avg_weights[key]+=client_weights_list[i][key]# 取平均值avg_weights[key]=torch.div(avg_weights[key],len(client_weights_list))# 更新全局模型self.global_model.load_state_dict(avg_weights)defget_weights(self):returnself.global_model.state_dict()

6. 步骤四:组合运行与结果展示

这是最激动人心的时刻,我们将启动训练循环。

操作:复制到第四个单元格并运行。

# --- 初始化 ---device=torch.device("cpu")server=Server(device)clients=[Client(i,client_loaders[i],device)foriinrange(5)]# 5个客户端print("启动")# --- 主循环 (3轮为例) ---forround_idxinrange(3):print(f"\n--- Round{round_idx+1}---")# 1. 服务器下发参数global_weights=server.get_weights()client_updates=[]# 2. 客户端并行训练forclientinclients:w_local=client.local_train(global_weights,epochs=1)client_updates.append(w_local)print(f"Client{client.client_id}已上传参数")# 3. 服务器聚合server.aggregate(client_updates)print("Server 完成参数聚合 (FedAvg)")print("\n训练结束!全局模型已更新。")

预期输出

如果一切顺利,你将看到如下输出:

🚀 联邦学习系统启动... --- Round 1 --- Client 0 已上传参数 ... Server 完成参数聚合 (FedAvg) --- Round 2 --- ... 训练结束!全局模型已更新。

常见报错 (Troubleshooting)

  1. ModuleNotFoundError:说明环境没激活。请检查命令行左侧是否有(fl_demo)字样。
  2. HTTP Error 503:MNIST 下载失败。请检查网络,或手动下载数据集放入data文件夹。
  3. RuntimeError: CUDA error:请确保代码中写的是device = 'cpu'

🎉祝你天天开心,我将更新更多有意思的内容,欢迎关注!

最后更新:2026年1月
作者:Echo

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/12 12:30:45

SAM 3GPU算力适配:梯度检查点+激活重计算节省40%显存

SAM 3GPU算力适配:梯度检查点激活重计算节省40%显存 1. SAM 3 是什么?图像与视频的“视觉理解助手” 你有没有试过给一张照片里的一只猫单独抠出来,或者想让一段视频里奔跑的小狗始终被高亮框住?过去这需要专业软件、大量手动操…

作者头像 李华
网站建设 2026/4/16 8:58:48

当灰狼优化算法遇上BiLSTM:参数调优的自动化实践

灰狼优化算法与BiLSTM的超参数自动化调优实战 在时间序列预测领域,BiLSTM(双向长短期记忆网络)因其出色的上下文捕捉能力而备受青睐。然而,BiLSTM的性能高度依赖于超参数的选择——从隐藏层节点数到学习率,每个参数都…

作者头像 李华
网站建设 2026/4/3 6:45:23

GLM-4-9B-Chat-1M助力企业知识管理:文档智能检索应用

GLM-4-9B-Chat-1M助力企业知识管理:文档智能检索应用 1. 为什么企业需要“能读懂整本手册”的AI助手? 你有没有遇到过这些场景? 法务同事花三天通读一份287页的并购协议,只为确认某一条款是否隐含风险; 研发团队每次…

作者头像 李华
网站建设 2026/4/18 5:35:25

VibeThinker-1.5B助力私有化部署智能判题系统

VibeThinker-1.5B助力私有化部署智能判题系统 在高校教学、编程竞赛培训和算法课程实践中,教师常面临一个现实困境:学生提交的代码五花八门,手动批改耗时费力,而通用大模型又容易在边界案例中给出错误解析或模糊反馈。更关键的是…

作者头像 李华
网站建设 2026/4/18 5:39:16

动手试了科哥的卡通化工具,结果让我惊呼太像了

动手试了科哥的卡通化工具,结果让我惊呼太像了 大家好,我是小陈,一个喜欢把AI工具用在日常创作里的普通用户。上周偶然看到朋友转发的“科哥人像卡通化工具”,标题写着“真人秒变二次元”,我第一反应是:又…

作者头像 李华