🛠️动手实战:环境配置 + 代码实现 + 避坑指南
🎯目标:抛开晦涩的公式,手把手教你在自己的电脑上搭建并运行第一个联邦学习模拟系统
💡核心:从安装软件到编写“数据切分、客户端训练、服务器聚合”的全流程
📋 目录
- 1. 准备工作:硬件与软件环境配置
- 2. 实战思维:如何在单机上模拟联邦?
- 3. 步骤一:数据准备与切分 (Data Splitting)
- 4. 步骤二:定义共享模型与客户端逻辑
- 5. 步骤三:定义服务器逻辑 (FedAvg)
- 6. 步骤四:组合运行与结果展示
1. 准备工作:硬件与软件环境配置
在写代码之前,我们先确保你的电脑(哪怕是普通笔记本)已经准备好了“厨房”。
1.1 硬件要求
对于本教程的 MNIST(手写数字)案例,无需昂贵的显卡。
| 硬件 | 最低要求 | 推荐配置 | 说明 |
|---|---|---|---|
| CPU | 任意双核 | i5 / R5 及以上 | 代码量小,CPU 运行更方便,无需配置显卡驱动。 |
| 内存 | 8 GB | 16 GB | 联邦学习需在内存中暂存多个模型副本,内存过小易报错。 |
1.2 软件安装 (Anaconda + PyTorch)
为了避免复杂的环境配置,我们采用最稳妥的方案:
- 安装 Anaconda(你的工具箱):
- 访问 Anaconda 官网 下载安装。
- 安装后打开Anaconda Prompt(黑色终端窗口)。
- 创建虚拟环境与安装 PyTorch(你的发动机):
在终端中依次输入以下指令(每行输入后按回车):
# 1. 创建一个名为 fl_demo 的环境conda create -n fl_demopython=3.9-y# 2. 激活环境 (最重要的一步,看到左侧括号变了才算成功)conda activate fl_demo# 3. 安装 PyTorch (CPU版) 和 Jupyter (编辑器)pipinstalltorch torchvision numpy jupyter- 启动编辑器:
在终端输入jupyter notebook,浏览器会自动弹出一个网页,点击右上角New->Python 3,即可开始写代码。
2. 实战思维:如何在单机上模拟联邦?
在真实世界中,联邦学习涉及 10 台手机和 1 台服务器通过网络传输。但在学习阶段,我们在一台电脑上通过for 循环来模拟这个过程。
- 数据隔离模拟:我们将一份大数据集强行切分成 10 份,分配给 10 个变量,假装它们互不通气。
- 通信模拟:变量之间的赋值(
server_model = client_model)代替了网络传输。
3. 步骤一:数据准备与切分 (Data Splitting)
我们需要编写代码,将 MNIST 数据集切分给 N 个客户端。
操作:将以下代码复制到 Jupyter 的第一个单元格并运行。
importtorchfromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataLoader,Subsetimportcopydefget_dataset(num_clients=5):""" 下载并切分 MNIST 数据集 """# 1. 下载数据# 注意:如果网络报错,请将 download=True 改为 False,并手动下载数据集到 ./data 目录train_data=datasets.MNIST(root='./data',train=True,download=True,transform=transforms.ToTensor())# 2. 模拟数据切分 (IID切分:每个用户拿到的数据分布相似)data_len=len(train_data)indices=list(range(data_len))split_size=data_len//num_clients client_loaders=[]foriinrange(num_clients):# 截取属于第 i 个用户的数据索引subset_indices=indices[i*split_size:(i+1)*split_size]subset=Subset(train_data,subset_indices)# 封装成 DataLoaderloader=DataLoader(subset,batch_size=32,shuffle=True)client_loaders.append(loader)returnclient_loaders# 测试一下client_loaders=get_dataset(num_clients=5)print(f"数据准备完毕!成功创建{len(client_loaders)}个客户端数据源。")4. 步骤二:定义共享模型与客户端逻辑
所有参与方必须使用相同的神经网络结构。客户端负责“接收模型 -> 训练 -> 返回参数”。
操作:复制到第二个单元格并运行。
importtorch.nnasnnimporttorch.nn.functionalasF# --- 1. 定义网络结构 ---classSimpleCNN(nn.Module):def__init__(self):super(SimpleCNN,self).__init__()self.conv1=nn.Conv2d(1,10,kernel_size=5)self.conv2=nn.Conv2d(10,20,kernel_size=5)self.fc=nn.Linear(320,10)defforward(self,x):x=F.relu(F.max_pool2d(self.conv1(x),2))x=F.relu(F.max_pool2d(self.conv2(x),2))x=x.view(-1,320)x=self.fc(x)returnx# --- 2. 定义客户端 (Client) ---classClient:def__init__(self,client_id,data_loader,device='cpu'):self.client_id=client_id self.data_loader=data_loader self.device=device self.model=SimpleCNN().to(self.device)deflocal_train(self,global_weights,epochs=1):# 加载服务器发来的参数self.model.load_state_dict(global_weights)# 本地训练 (常规的 PyTorch 训练流程)optimizer=torch.optim.SGD(self.model.parameters(),lr=0.01,momentum=0.5)self.model.train()forepochinrange(epochs):fordata,targetinself.data_loader:data,target=data.to(self.device),target.to(self.device)optimizer.zero_grad()output=self.model(data)loss=F.cross_entropy(output,target)loss.backward()optimizer.step()# 关键:只返回参数 (state_dict),不返回数据!returncopy.deepcopy(self.model.state_dict())5. 步骤三:定义服务器逻辑 (FedAvg)
服务器通过FedAvg (联邦平均算法)将收集到的参数进行加权平均。
操作:复制到第三个单元格并运行。
classServer:def__init__(self,device='cpu'):self.global_model=SimpleCNN().to(device)self.device=devicedefaggregate(self,client_weights_list):""" FedAvg 核心:对参数取平均 """# 拿出第一个客户端的参数作为基准avg_weights=copy.deepcopy(client_weights_list[0])# 逐层累加其他客户端的参数forkeyinavg_weights.keys():foriinrange(1,len(client_weights_list)):avg_weights[key]+=client_weights_list[i][key]# 取平均值avg_weights[key]=torch.div(avg_weights[key],len(client_weights_list))# 更新全局模型self.global_model.load_state_dict(avg_weights)defget_weights(self):returnself.global_model.state_dict()6. 步骤四:组合运行与结果展示
这是最激动人心的时刻,我们将启动训练循环。
操作:复制到第四个单元格并运行。
# --- 初始化 ---device=torch.device("cpu")server=Server(device)clients=[Client(i,client_loaders[i],device)foriinrange(5)]# 5个客户端print("启动")# --- 主循环 (3轮为例) ---forround_idxinrange(3):print(f"\n--- Round{round_idx+1}---")# 1. 服务器下发参数global_weights=server.get_weights()client_updates=[]# 2. 客户端并行训练forclientinclients:w_local=client.local_train(global_weights,epochs=1)client_updates.append(w_local)print(f"Client{client.client_id}已上传参数")# 3. 服务器聚合server.aggregate(client_updates)print("Server 完成参数聚合 (FedAvg)")print("\n训练结束!全局模型已更新。")预期输出
如果一切顺利,你将看到如下输出:
🚀 联邦学习系统启动... --- Round 1 --- Client 0 已上传参数 ... Server 完成参数聚合 (FedAvg) --- Round 2 --- ... 训练结束!全局模型已更新。常见报错 (Troubleshooting)
ModuleNotFoundError:说明环境没激活。请检查命令行左侧是否有(fl_demo)字样。HTTP Error 503:MNIST 下载失败。请检查网络,或手动下载数据集放入data文件夹。RuntimeError: CUDA error:请确保代码中写的是device = 'cpu'。
🎉祝你天天开心,我将更新更多有意思的内容,欢迎关注!
最后更新:2026年1月
作者:Echo