联邦学习

联邦学习实验指导书:基于PyTorch的MNIST手写数字识别

实验目标

  • 理解联邦学习核心流程:掌握“本地训练-参数上传-全局聚合”的循环机制。
  • 动手实现FedAvg算法:使用Python和PyTorch构建一个简单的横向联邦学习系统。
  • 观察数据异质性影响:对比独立同分布与非独立同分布数据对模型性能的影响。
  • 培养工程实践能力:学习模块化编程、参数配置和实验结果分析。

实验环境

  • 操作系统:Windows/Linux/macOS
  • 编程语言:Python 3.8+
  • 核心库:PyTorch, torchvision, numpy, json
  • 硬件:CPU即可,GPU可加速(可选)

安装依赖

pip install torch torchvision numpy

项目结构

建议创建一个名为federated_mnist的文件夹,并按以下结构组织代码:

federated_mnist/
├── main.py              # 主程序入口
├── client.py            # 客户端类定义
├── server.py            # 服务器类定义
├── models.py            # 神经网络模型定义
├── datasets.py          # 数据加载与分区
├── utils/
│   ├── conf.json        # 配置文件
│   └── sampling.py      # 数据采样工具(用于生成非独立同分布数据)
└── data/                # 自动下载MNIST数据集

实验步骤

步骤1:配置实验参数

  • utils/conf.json中设置实验参数。这是整个实验的“控制中心”。
{
  "model_name": "mlp",
  "no_models": 10,
  "type": "MNIST",
  "global_epochs": 50,
  "k": 5,
  "local_epochs": 5,
  "lr": 0.01
}
  • 参数说明
    • no_models:总客户端数量(模拟10家医院)。
    • k:每轮参与训练的客户端数量(模拟每次随机选5家)。
    • global_epochs:全局通信轮次。
    • local_epochs:每个客户端本地训练的轮次。

步骤2:定义神经网络模型

  • models.py中定义一个简单的多层感知机(MLP)用于MNIST分类。
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

步骤3:实现客户端逻辑

  • client.py中实现客户端类,负责本地训练。
import torch
import torch.optim as optim

class Client:
    def __init__(self, model, train_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.config = config
        self.optimizer = optim.SGD(self.model.parameters(), lr=config['lr'])

    def local_train(self):
        self.model.train()
        for _ in range(self.config['local_epochs']):
            for data, target in self.train_loader:
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
        return self.model.state_dict()

步骤4:实现服务器逻辑

  • server.py中实现服务器类,负责模型聚合。
import copy

class Server:
    def __init__(self, global_model, config):
        self.global_model = global_model
        self.config = config

    def aggregate(self, client_weights):
        # FedAvg: 加权平均
        total_num_samples = sum([w['num_samples'] for w in client_weights])
        avg_state_dict = copy.deepcopy(self.global_model.state_dict())
        
        for key in avg_state_dict.keys():
            avg_state_dict[key] = sum([w['state_dict'][key] * (w['num_samples'] / total_num_samples) for w in client_weights])
        
        self.global_model.load_state_dict(avg_state_dict)
        return self.global_model

步骤5:数据加载与非独立同分布划分(核心)

  • datasets.py中加载MNIST并实现非独立同分布数据划分。这是理解联邦学习挑战的关键。
from torchvision import datasets, transforms
import numpy as np

def load_mnist_data():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    return train_dataset

def split_data_iid(dataset, num_clients):
    # 独立同分布:随机均匀分配
    indices = list(range(len(dataset)))
    np.random.shuffle(indices)
    client_indices = np.array_split(indices, num_clients)
    return client_indices

def split_data_noniid(dataset, num_clients):
    # 非独立同分布:按标签分配,每个客户端只拥有部分类别的数据
    # 例如:客户端0只有0,1类数据,客户端1只有2,3类数据...
    labels = np.array(dataset.targets)
    client_indices = []
    samples_per_client = len(dataset) // num_clients
    
    for i in range(num_clients):
        # 为每个客户端分配2个类别的标签
        label_start = (i * 2) % 10
        label_end = (label_start + 2) % 10
        if label_start < label_end:
            indices = np.where((labels >= label_start) & (labels < label_end))[0]
        else:
            indices = np.where((labels >= label_start) | (labels < label_end))[0]
        client_indices.append(indices)
    
    return client_indices

步骤6:主程序入口

  • main.py中整合所有模块,启动联邦训练。
import json
import torch
from models import MLP
from client import Client
from server import Server
from datasets import load_mnist_data, split_data_iid, split_data_noniid
from torch.utils.data import DataLoader, Subset

# 加载配置
with open('utils/conf.json', 'r') as f:
    config = json.load(f)

# 1. 初始化全局模型
global_model = MLP()

# 2. 加载并划分数据
dataset = load_mnist_data()
# 切换iid或noniid来观察不同效果
client_indices = split_data_noniid(dataset, config['no_models'])

# 3. 创建客户端
clients = []
for i in range(config['no_models']):
    client_dataset = Subset(dataset, client_indices[i])
    train_loader = DataLoader(client_dataset, batch_size=32, shuffle=True)
    client = Client(MLP(), train_loader, config)
    clients.append(client)

# 4. 创建服务器
server = Server(global_model, config)

# 5. 联邦训练循环
for global_epoch in range(config['global_epochs']):
    print(f"第 {global_epoch+1} 轮全局训练...")
    
    # 随机选择k个客户端
    selected_clients = np.random.choice(clients, config['k'], replace=False)
    client_weights = []
    
    for client in selected_clients:
        # 本地训练
        state_dict = client.local_train()
        # 记录样本数用于加权平均
        num_samples = len(client.train_loader.dataset)
        client_weights.append({'state_dict': state_dict, 'num_samples': num_samples})
    
    # 全局聚合
    server.aggregate(client_weights)
    
    # 简单评估(可选)
    # 可以在这里添加测试代码,观察全局模型在测试集上的准确率

print("联邦训练完成!")

实验任务与思考

  1. 基础任务:运行代码,观察训练过程。尝试修改conf.json中的global_epochslocal_epochs,观察对最终模型性能的影响。
  2. 核心探究:在main.py中,将split_data_noniid改为split_data_iid,重新运行。对比两种数据分布下,模型的收敛速度和最终准确率有何不同?为什么?
  3. 进阶挑战:尝试在server.pyaggregate函数中,实现一个简单的差分隐私机制。在聚合前,为每个客户端上传的参数添加高斯噪声。观察噪声大小(epsilon)对模型性能的影响。

预期结果

  • 在独立同分布数据下,模型应能较快收敛,准确率较高。
  • 在非独立同分布数据下,模型收敛会变慢,最终准确率可能较低,这直观地展示了数据异质性对联邦学习的挑战。

拓展学习

  • 框架学习:尝试使用更成熟的联邦学习框架,如Flower或FATE,完成同样的实验。
  • 算法改进:研究FedProx算法,它如何解决非独立同分布问题?尝试在你的代码中实现。
  • 安全攻防:模拟一个恶意客户端,在上传参数时故意发送错误的梯度(投毒攻击),观察对全局模型的影响,并思考如何防御。
posted @ 2026-04-13 22:58  左耳听风  阅读(24)  评论(0)    收藏  举报