DLAI-联邦学习笔记-全-
DLAI 联邦学习笔记(全)
001:课程介绍与动机 🎯
在本节课中,我们将要学习联邦学习的基本概念、其核心动机以及它如何解决传统机器学习中的数据隐私和集中化挑战。我们将通过一个简单的例子来理解联邦学习的工作流程。
欢迎来到由Flower Labs合作打造的《联邦学习入门》课程。我很高兴向大家介绍本课程的讲师Daniel Boyitto,他是开源联邦学习框架Flower的创始人之一。
感谢Andrew。我很高兴来到这里。在本课程中,你将使用Flower探索联邦学习。Flower是一个流行的开源框架,拥有庞大的AI研究者和开发者社区。它将使你能够构建一个联邦学习系统,并以一种增强隐私保护的方式运行分布式机器学习训练任务。
什么是联邦学习?🤔
上一节我们介绍了课程和讲师,本节中我们来看看联邦学习要解决的核心问题。
假设你想在医学图像上训练一个模型,但这些图像分散在不同的医院中。由于隐私和法规限制,可能无法将所有图像集中收集到一个地方。而通过联邦学习,你可以在分布式数据源上进行训练,而无需将所有数据集中起来。
其核心思想是:不将数据移动到训练处,而是将训练移动到数据处。具体做法是在所有医院运行分布式训练任务,之后仅集中模型参数,而非原始数据本身。通过这种方式,最终可以得到一个受益于所有医院数据的模型,而无需任何原始数据离开其所在的医院。
联邦学习工作流程示例 ✍️
在理解了基本概念后,我们通过一个具体例子来看看联邦学习是如何运作的。
在本课程中,你将使用MNIST手写数字数据集进行探索。假设数据被分割,每个部分都缺少一些数字。例如,一部分数据缺少数字1,另一部分则缺少数字7。
以下是联邦学习在此场景下的工作步骤:

- 本地训练:你使用你拥有的手写数字数据(例如,缺少数字1的数据集)在本地训练模型。同时,其他人使用他们自己的数据(例如,缺少数字7的数据集)进行训练。
- 参数上传:训练完成后,每个人将更新后的模型参数发送到中央服务器。
- 聚合更新:服务器聚合来自所有数据源的模型更新,从而改进一个全局模型。关键点在于,服务器无法访问任何个体的原始数据源。
- 模型共享:改进后的全局模型可以与所有人共享。
这个过程可以公式化地表示为一种常见的聚合方法(如FedAvg):
w_global = ∑ (n_k / n) * w_k
其中,w_global是全局模型参数,n_k是客户端k的数据量,n是总数据量,w_k是客户端k训练后的模型参数。
联邦学习的优势与意义 🌟
了解了工作流程后,我们来看看联邦学习为何如此令人兴奋及其重要意义。
联邦学习让我们能够在数据始终由拥有它的用户和组织控制的前提下,构建强大而精确的模型。通过在单个设备或服务器上本地训练模型,我们可以利用广泛的数据,而无需在中心共享实际数据。
这种方法对于医疗保健和金融等领域尤其重要,因为这些领域的数据非常敏感,需要被保护。联邦学习使我们能够为那些以前没有足够数量或足够多样性训练数据的任务训练模型。
本课程你将学到什么 📚

在概述了联邦学习的价值之后,本节将简要介绍你将在本课程中掌握的具体技能。
在本课程中,你将学习:
- 联邦训练过程的工作原理。
- 如何调整和优化联邦学习系统。
- 如何在联邦学习中考虑数据隐私。
- 如何考虑联邦学习过程中的带宽使用。
你还将学习差分隐私,通常简称为DP。这是一种保护个体数据点(如消息或图像)的技术。在本课程中,我们将描述一种向模型权重添加少量噪声的技术,以掩盖训练集中可能存在的任何潜在的私人敏感细节,同时仍然允许模型进行有效学习。
你将获得联邦学习系统中不同组件的概述,学习如何自定义和调整它们,以及如何协调训练过程以构建更好的模型。
课程结构与致谢 🙏
最后,我们来看看第一课的具体安排,并对课程制作人员表示感谢。
在第一课中,你将从使用联邦学习的动机开始。你将探索传统集中式机器学习的挑战(即数据必须收集在一处),并了解联邦学习如何通过分布式训练来解决这个问题。


这听起来很棒,让我们进入下一个视频,正式开始学习。

许多人为创建本课程付出了努力。我要感谢Flower Labs的Mohammed Nassri、Ruth Galino、Javier Fernandez Marquina,以及DeepLearning.AI的Dila Eadin和Jeff Ladray。
本节课总结:本节课我们一起学习了联邦学习的核心动机,即在不集中原始数据的前提下进行协同模型训练,以保护数据隐私。我们通过一个手写数字识别的例子,了解了联邦学习的基本工作流程(本地训练、参数上传、服务器聚合、模型共享),并认识了其在不同领域的应用价值。最后,我们预览了本课程将要涵盖的关键技术主题,如系统调优和差分隐私。
002:为什么需要联邦学习 🧠


在本节课中,我们将学习为什么联邦学习是解锁大量当前无法访问的训练数据的关键。你将看到训练数据对于训练优秀模型的重要性,同时也会了解传统训练方法的局限性。我们还将通过实例,了解联邦学习如何用于在分布于不同组织甚至数亿用户设备的数据上训练模型。
数据的重要性与挑战
上一节我们介绍了课程目标,本节中我们来看看数据在现代机器学习中的核心地位。
以近期发布的 Llama 3 为例,它仅在 Llama 2 发布九个月后问世。Llama 3 的性能相比 Llama 2 有了巨大飞跃。最令人印象深刻的细节之一是,最小的 Llama 3 模型(8B 参数)在多项指标上大幅超越了最大的 Llama 2 模型(70B 参数)。


这是如何实现的?Llama 2 和 Llama 3 之间最显著的变化之一是训练数据量。根据官方公告,Llama 3 的训练数据集是 Llama 2 的七倍大,并且包含了四倍多的代码。Llama 2 使用了大约 2 万亿个词元进行训练,而 Llama 3 将这个数字增加到了 15 万亿。这充分证明了海量、高质量训练数据的重要性。
与此同时,业界也在讨论大型语言模型是否正在耗尽训练数据。全球可用数据的总量难以精确估计。根据 Epoch AI 博客的近期估算,全球可用于 LLM 训练的高质量英文文本数据大约在 15 万亿词元左右。当前的 LLM 训练集似乎已经接近用尽所有可用的高质量英文文本。我们或许能通过数据增强等方式略微增加数据量,但这被认为是公开可用训练数据的上限。
有趣的是,即使是 LLM 自身也认为 LLM 正在耗尽训练数据,并且这个问题会随着时间的推移而加剧。
公开数据 vs. 私有数据
一个较少被讨论的重要方面是公开数据与私有数据的对比。
与 FineWeb 数据集中的 15 万亿词元和非英文数据中的 18 万亿词元相比,仅私人存储的即时消息中就估计有 650 万亿词元,所有存储的电子邮件中甚至有 1200 万亿词元。
这并非建议将这些数据纳入 LLM 训练,而是作为一个数据点,用以对比全球公开数据与敏感私有数据的数量级。
这引出了一个有趣的现象:我们知道数据对于训练优秀模型至关重要,我们似乎正在耗尽训练数据,但同时,又有海量的数据未被利用。
我们将在课程 2 中深入探讨基于私有数据的联邦 LLM 微调。在本课程中,我们通过介绍联邦学习来为此奠定基础。
敏感数据的分布式本质
让我们更深入地探讨敏感数据的主题。数据天然是分布式的。

以下是数据在不同领域的分布情况:
- 医疗保健:数据分布在不同的医院。
- 政府:数据分布在不同的政府机构。
- 金融:数据分布在不同的监管区域。
- 制造业:数据分布在不同的工厂。


在用户设备层面,敏感数据不仅存在于手机和笔记本电脑上,也存在于其他类型的智能设备中,如汽车甚至家中的扫地机器人。
传统集中式训练假设数据是集中的,它只在一个数据集上运行,而忽略了所有其他数据集。结果是,大量有价值的数据并未被用于训练。实际上,世界上大部分数据都无法轻易用于模型训练。
常见的解决方法是尝试将更多数据收集到一个地方,以增加单个数据集的规模。但在太多情况下,收集数据根本行不通。数据需要移动,但这往往由于多种原因而不可行:数据可能很敏感、数据量可能太大、用户隐私可能阻止我们收集数据、法规可能强制数据留在特定区域,有时这根本不切实际。
数据缺失的影响:一个实验
你可能会想,这实际上是一个多大的问题?如果我有数据,但分布不均,会发生什么?为了理解这一点,我们将构建三个数据集。
以下是实验步骤:
- 导入工具:首先导入一些工具函数,包括
utils.py中的MNIST数据集处理函数。 - 加载并分割数据:使用
torchvision.datasets.MNIST下载 MNIST 数据集,并将其随机分割成三个部分,以模拟分布在三个分区(如三个组织或用户设备)的数据。# 示例代码:分割数据集 total_len = len(train_dataset) part_size = total_len // 3 part1, part2, part3 = random_split(train_dataset, [part_size, part_size, total_len - 2*part_size]) - 模拟非独立同分布:通过在每个数据集中排除特定的数字来改变数据分布,模拟现实世界中不同数据源拥有不同数据特性的情况。
- 数据集 1 排除数字 1、3、7。
- 数据集 2 排除数字 2、5、8。
- 数据集 3 排除数字 4、6、9。
- 训练三个独立模型:在每个数据集上分别训练一个具有相同架构的简单神经网络模型(例如,一个包含两个全连接层的 PyTorch 模型)。训练 10 个周期,观察损失下降。
- 评估模型性能:在完整的 MNIST 测试集上评估每个模型。然后,为每个模型创建特定的测试子集,仅包含其训练时未见过的数字(例如,为模型 1 创建仅包含数字 1、3、7 的测试集),并再次评估。
- 分析混淆矩阵:计算每个模型在完整测试集上的混淆矩阵,以深入了解模型对每个数字的分类表现。
实验结果:
- 每个模型在完整测试集上的准确率大约在 65% 到 70%,这优于随机猜测(10%),但由于缺少三个数字的训练样本,性能受限。
- 然而,在仅包含其训练时缺失数字的特定测试子集上,每个模型的准确率都是 0%。
- 混淆矩阵显示,对于训练中缺失的标签(例如,模型 1 中的数字 1、3、7),模型学会了从不预测这些类别,而是错误地预测为其他看似相近的类别(例如,将数字 1 预测为 2 或 8)。
这个实验清晰地展示了训练数据的重要性以及数据缺失的后果:如果训练数据缺失某些类别,模型不仅无法识别它们,还会学会做出错误的预测。

联邦学习的核心理念
现在你已经看到了集中式训练的问题,以及当数据缺失某些部分时它是如何失效的。那么,联邦学习能让你做什么?它如何帮助解决这种情况?
在理想情况下,我们可以在所有可用数据集上训练模型,同时每个参与者都能保留对自己数据的控制权,或保持其数据的私密性。用户可以将数据保持私有,但仍然可以进行训练协作。在医疗保健等关键领域训练模型时,联邦学习是实现这一未来的关键组成部分。
联邦学习的核心思想是将模型训练移动到数据所在之处,而让数据保持原位。数据可以保留在组织的孤岛中或用户设备上,组织或用户保留对其数据的完全控制权。模型训练发生在数据所在的地方:公司的 GPU 集群、组织存储数据的云账户,甚至用户设备上。
联邦学习协调这些不同数据集和设备之间的训练过程。这使得联邦学习能够访问更多的数据和计算资源,包括组织孤岛中的敏感分布式数据以及用户设备上的数据。
在下一课中,你将确切了解这是如何运作的。但首先,让我们看几个联邦学习在工业界的真实案例。
联邦学习的实际应用案例

以下是联邦学习在不同场景下的应用实例:
- 金融领域:金融数据受到严格监管。例如,美国客户的交易数据必须存储在美国,欧洲客户的交易数据必须存储在欧洲。这些交易数据对于训练反洗钱模型非常有价值。借助联邦学习,可以在保持数据存储在世界各地不同区域的同时,仍然能够跨这些分布式数据训练一个统一的模型。
- 谷歌键盘(Gboard):这是联邦学习的另一个极端案例。该系统部署在数亿台用户设备上。当你使用谷歌键盘输入句子时,键盘会尝试预测你接下来要输入的单词或补全句子(智能撰写)。这些功能由语言模型驱动。用于训练的用户数据非常敏感,无法被收集。谷歌实际上是最早提出并开创联邦学习的公司之一,以使这些模型能够在无需收集数据的情况下在用户设备上进行训练。
- 医疗健康协作:前两个例子我们看到的是单个组织拥有分布式数据(跨区域或跨设备)。第三个例子的特殊之处在于多个组织之间的协作。在医疗领域,数据通常分布在许多医院。在下一课中将用到的 Flower 框架,曾被英国国家医疗服务体系用于与牛津大学合作,基于 13 万名患者的血液和生命体征数据,训练了一个早期的 COVID 筛查模型。这是一个激动人心的项目,因为它允许不同的医院协作训练模型。这种方法是在医疗健康领域推广 AI 的关键推动力,因为单个组织几乎从未拥有足够的数据来训练现代数据饥渴型模型架构。

总结
本节课中我们一起学习了以下内容:
- 数据的重要性:数据的体量和多样性对于训练优秀模型至关重要。
- 数据的困境:我们似乎正在耗尽(公开)训练数据,但同时又有海量的(私有)数据未被使用。
- 根本原因:数据通常是分布式的,而传统训练方法假设数据是集中的。集中化数据通常很困难或几乎不可能。
- 联邦学习的价值:你看到了联邦学习在分布式数据上运行,它被部署在许多不同行业,并运行在分布式设备或组织孤岛之上。
- 实际应用:我们探讨了联邦学习在金融、消费级应用(谷歌键盘)和跨组织医疗协作中的实际案例。

在下一课中,我们将深入探讨联邦学习具体是如何工作的。
003:联邦训练流程 🚀



在本节课中,我们将使用 Flower 和 PyTorch 构建第一个联邦学习项目。你将学习联邦学习如何让你在分布式数据上训练 AI 模型。我们将使用的模型和数据只是一个示例,用于展示联邦学习的实际运作,但请记住,这种方法可以扩展到大多数其他模型、数据集,甚至不同的框架,如 TensorFlow、JAX、Hugging Face Transformers 和 Apple 的 MLX。现在,让我们开始构建。
联邦学习系统概述
上一节我们介绍了联邦学习的基本概念,本节中我们来看看一个基础联邦学习系统的构成。
在一个基础的联邦学习系统中,有一个服务器和多个客户端。服务器本身通常没有任何数据。它可能有一些用于评估全局模型的数据。但在经典的联邦学习中,它没有训练数据。客户端才是拥有实际训练数据的一方。
如果你有一个由五家医院协作进行模型训练的系统,那么你会有五个客户端,每家医院一个。每个客户端都运行在各自的医院环境中,并且能够访问该特定医院的数据。如果你有一个由一亿个用户设备持有数据的系统,那么你就有一亿个客户端,每个用户设备上运行一个客户端。
服务器的角色是协调这些客户端之间的训练。客户端的角色是在各自的本地数据上进行实际的训练。服务器和客户端都拥有自己的模型副本。服务器上的模型通常称为全局模型。客户端上的模型通常称为本地模型。
跨客户端的训练流程
上一节我们介绍了系统构成,本节中我们来看看训练是如何在多个客户端之间进行的。
整个过程从服务器初始化全局模型参数开始。服务器将全局模型的参数发送给客户端。在示例中,我们有五个客户端(平板、台式机、手机、笔记本电脑和服务器)。这五个客户端随后在其本地数据上训练模型。它们只训练一小段时间,通常只在本地数据集上训练一个轮次。
本地训练结束后,客户端将其改进后的模型发送回服务器。现在服务器拥有五个改进后的模型,它们的权重都略有不同。但我们需要的是一个模型,而不是五个。为了得到一个模型,服务器会聚合这五个模型。聚合模型的方法有多种,但最常见的一种是简单地平均权重。
在第一次聚合之后,你会得到一个略有改进的全局模型版本。使用这个新版本的全局模型,你将重复之前描述的步骤:将新模型发送给客户端,客户端在其本地训练数据上训练,它们发回改进后的模型,然后服务器聚合这些模型。联邦学习是一个迭代过程,它会一遍又一遍地重复这些所谓的轮次,直到收敛。
联邦学习算法
以下是联邦学习算法的更正式描述:
- 初始化:服务器初始化全局模型。
- 通信轮次:对于每个通信轮次,服务器将全局模型发送给参与的客户端,每个客户端接收全局模型。
- 客户端训练与模型更新:每个参与的客户端在本地数据集上训练接收到的模型。完成后,客户端将其本地更新后的模型发送回服务器。
- 模型聚合:服务器使用聚合算法(例如联邦平均)聚合从所有客户端收到的更新模型。联邦平均是对从客户端收到的所有模型更新进行加权平均,权重由每个特定客户端用于训练的训练样本数量决定。
- 收敛检查:如果满足收敛标准,则结束联邦学习过程。否则,进入下一个通信轮次(即步骤2)。
构建第一个联邦学习项目
在上一课中,我们构建了三个独立的数据集并在其上训练了三个独立的模型。在本课中,我们将连接这些独立的模型,目标是在三个分布式数据集上训练一个协作模型。让我们进入实验环节。
和之前一样,我们将从一些导入开始。在本课中,我们将使用 Flower 联邦学习框架来“联邦化”之前使用的训练流程。因此,除了实用程序外,我们还有一些与 Flower 相关的导入,例如 ClientApp、ServerApp 和 FedAvg(联邦平均策略)。
准备数据集
第一步是重新创建上一课中用于在 MNIST 数据上训练的三个数据集。MNIST 训练数据集以与之前相同的变换加载以进行归一化。然后,数据集被分成三个部分:Part1、Part2 和 Part3,大小与上一课相同。使用相同的随机种子以确保可复现性。数字 1、3 和 7 从 Part1 中排除;数字 2、5 和 8 从 Part2 中排除;数字 4、6 和 9 从 Part3 中排除。为了后续使用,我们将所有三个训练数据集放入一个名为 train_sets 的列表中。
完整的 MNIST 测试数据集以相同的变换加载,并且我们还创建了与第 1 课相同的三个测试数据子集:test_set_137、test_set_258 和 test_set_469。
模型权重交换函数
在联邦学习中,需要在服务器和客户端之间交换模型参数。当客户端从服务器接收到模型参数时,它需要用这些新参数更新本地模型。当客户端完成训练时,它需要将本地模型参数的最新版本发送回服务器。为了实现这一点,我们需要两个函数:set_weights 和 get_weights。
get_weights 函数接受一个参数,即对我们简单 PyTorch 模型的引用。然后它遍历 state_dict 中的项,将每一项转换为 NumPy 数组,并返回包含所有这些数组的列表。我们在本地模型训练完成后使用 get_weights 来获取模型更新后的权重并将其发送回服务器。
set_weights 函数则相反。它接受两个参数:对我们简单 PyTorch 模型的引用和一个 NumPy 数组列表。然后它使用这个 NumPy 数组列表来更新模型 state_dict 中的所有项。我们在本地模型训练之前使用 set_weights,利用从服务器接收到的新权重来更新模型的权重。请注意,这两个函数都是模型特定的。如果你使用不同的模型,可能需要相应地调整 set_weights 和 get_weights。
创建 Flower 客户端
为了连接你现有的训练和评估流程,你需要编写一个 Flower 客户端。FlowerClient 使用现有的函数(如 train_model 和 evaluate_model)来使 Flower 框架能够在一组参与的客户端上编排联邦训练。Flower 客户端类被定义为 NumPyClient 的子类。你向构造函数传递三个参数来初始化一个 Flower 客户端对象:net(我们在 PyTorch 中实现的简单神经网络模型)、train_set(一个特定客户端的训练数据集)和 test_set(一个特定客户端的测试数据集)。
Flower 客户端通常定义两个方法:fit 方法使用提供的参数和本地训练数据训练神经网络;evaluate 方法使用提供的参数和本地测试数据集评估神经网络的性能。
为了使 Flower 框架在需要时能够创建客户端对象,我们需要实现一个名为 client_fn 的函数,该函数按需创建 Flower 客户端实例。这对于优化资源利用率是必要的。联邦训练可以轻松扩展到数百个客户端,但在构建此类系统时,你希望在一台机器上高效地模拟它们。每当 Flower 需要一个特定客户端的实例来调用 fit 或 evaluate 时,它就会调用 client_fn。这使得框架能够在特定客户端对象不再需要时丢弃它们并释放资源。
最后,我们通过传入之前定义的 client_fn 来创建一个 ClientApp 实例。ClientApp 是客户端端所有操作的入口点,你将在接下来的课程中了解 ClientApp 的更多功能。
创建服务器端评估与聚合
现在你有了一个可以执行本地训练和评估的 ClientApp,你需要一个服务器端的对应部分来聚合从客户端收到的更新模型。当然,你也想看看全局模型的表现如何。
为此,你首先创建一个名为 evaluate 的评估函数。evaluate 接受三个参数:当前服务器轮次、全局模型参数的最新版本(一个 NumPy 数组列表)和一个配置字典。类似于 FlowerClient.evaluate,它使用 set_weights 用最新参数更新模型。它在多个测试数据集上评估模型的性能。在我们的例子中,是 test_set、test_set_137、test_set_258 和 test_set_469,以评估其在完整 MNIST 测试集以及不同数字子集上的准确性。它会打印评估结果,包括所有数字的准确率以及特定子集(137、258 和 469)在每个服务器轮次上的准确率,这样我们可以看到准确率在联邦学习的多轮中是如何演变的。如果当前服务器轮次是最后一轮(由 server_round == 3 表示),它会计算并绘制模型在整个测试数据集上预测的混淆矩阵。这将使你能够理解联邦模型与第 1 课中的三个独立模型相比表现如何。
要创建 ServerApp,你需要决定要使用哪种策略。策略是一种抽象,它实现了服务器端的联邦学习算法。我们之前介绍过联邦平均,但还有许多其他算法,如 FedAdam、FedMedian 和 Q-FedAvg,其中许多在 Flower 中作为内置策略提供。让我们从简单的联邦平均开始。
要初始化 FedAvg,你需要传递 fraction_fit(选择用于训练的可用客户端的比例)、fraction_eval(选择用于评估的可用客户端的比例)、initial_parameters(初始模型权重)和 evaluate_fn(用于服务器端模型评估的函数)。
现在你有了一个策略,你可以创建一个 ServerApp 实例。有了 ClientApp 和 ServerApp,你终于可以开始训练了。一个真正的联邦学习系统将分布在多台服务器或用户设备上。在这个笔记本环境中,你通过在一台机器上运行所有内容来模拟这样的系统。为此,你使用一个名为 run_simulation 的函数,它接受三个参数:ServerApp、ClientApp 和 num_supernodes(要模拟的客户端数量)。run_simulation 调用客户端 supernodes 是为了强调这些节点在联邦学习过程中的重要性,与传统的客户端-服务器设置(拥有强大的服务器和瘦客户端)相比,在联邦学习中,客户端才是真正的明星。它们拥有宝贵的数据和进行计算训练的能力。让我们运行它。

运行模拟并分析结果
你可以看到 run_simulation 启动了 Flower 服务器应用,使用一个执行三轮联邦学习的配置。它告诉我们它正在使用策略提供的初始全局模型参数。然后它继续使用你之前定义的评估函数来评估这些初始全局模型参数。这意味着它评估所有数字的测试准确率,也评估你之前定义的三个不同子集上的准确率。然后它继续进入联邦学习的第一轮。策略从总共三个可用客户端中采样三个客户端。它对这些客户端进行采样,将全局模型参数发送给这三个客户端,并要求这些客户端在其本地数据集上执行训练。这需要一分钟才能完成。
你现在可以看到服务器收到了三个结果,零个失败。所有三个客户端都将其结果提交回服务器。然后服务器继续再次调用评估函数,以查看新聚合的全局模型在测试集(完整测试集以及你之前定义的三个不同测试集)上的表现。服务器继续进行第二轮。它再次从总共三个可用客户端中采样三个客户端。你可以看到,服务器再次收到了三个结果,零个失败。它继续再次调用评估函数来评估新聚合的全局模型。然后进入第三轮。对于第三轮,它再次从三个可用客户端中采样三个客户端。这又需要一分钟才能在客户端完成。当三个客户端完成其本地训练时,它们将其更新后的模型发送回服务器。你可以看到服务器收到了三个结果,零个失败。然后它继续调用 evaluate 来评估最终全局模型在完整 MNIST 测试集以及你之前构建的三个特定集合上的表现。
请记住,在第 1 课中,当你训练三个独立模型时,这些模型达到了大约 65% 到 70% 的准确率。在这里,通过联邦学习,你可以看到准确率大幅跃升至 96%。这意味着全局模型比第 1 课中训练的任何单个模型都要好得多。也许更有趣的是,你可以看到在这些特定子集上,全局模型的准确率从之前的 0% 跃升至 94% 到 96% 之间。观察混淆矩阵,我们看到与第 1 课中训练的单个模型的混淆矩阵相比,情况大不相同。该模型学会了如何分类所有数字,即使这些数字在其中一个数据集中缺失。因此,我们不再看到任何列中只有零的情况。

课程总结

在本节课中,我们一起学习了联邦学习的核心流程。联邦学习是一个迭代过程,拥有数据的客户端训练模型,而通常没有数据的服务器则聚合模型更新。你通过 Flower 的 ClientApp 定义客户端端的训练、评估或分析,并通过 Flower 的 ServerApp 定义服务器端的配置或聚合。在开发过程中,你通常在一台机器上模拟这些系统,但在生产环境中,你将它们部署在拥有各个数据集的不同机器上。
004:调优与自定义 🛠️

在本节课中,我们将对第2课中构建的Flower项目进行自定义和调优。你将学习联邦系统中通常需要调优的多个方面,并将在本课中实现一个自定义的训练调度方案。

好了,让我们开始吧。
概述
与传统的集中式训练相比,联邦学习引入了一些额外的概念和组件,这些组件在训练过程中可以进行自定义和调优。本节课我们将探讨这些可调优的方面,并学习如何在Flower框架中实现它们。
联邦学习中的可调优组件
上一节我们介绍了联邦学习的基本流程,本节中我们来看看其中可以自定义和调优的关键组件。
客户端选择策略
一个重要的因素是,如何选择参与每一轮训练的客户端。在下图中,你可以看到五个客户端。与其将全局模型发送给所有五个客户端,你可以在第一轮只选择三个客户端发送模型。这三个客户端将照常参与训练。然后在下一轮,你可以选择另外三个客户端。
在只有少数客户端的场景中,答案通常是每轮选择所有客户端。如果你像图中一样只有五个客户端,你可能会在联邦学习的每一轮中都选择它们全部。然而,在拥有大量客户端的场景中,你通常不会这样做。事实上,研究表明,选择越来越多的客户端会产生收益递减效应。相反,你只会从可用客户端中选择一个子集。在拥有数百万客户端的移动设备场景中,你通常每轮只选择几百个客户端,或者根据任务最多选择几千个。
以下是选择客户端的不同策略:
- 随机选择:一种非常常见的方法是随机选择客户端。
- 周期性训练:严格来说,这并非联邦学习,但可能非常有用。在这种方法中,你会将模型发送给一个客户端,让该客户端训练,然后发送回服务器,再发送给下一个客户端,如此一个接一个地进行。
客户端配置
一旦选择了客户端,你需要决定如何配置它们。你需要确定客户端应该做什么、训练模型多长时间、应该使用哪些超参数,以及客户端是否需要知道其他信息来正确执行训练或评估。
你可以看到客户端上突出显示的小模型图标。这表示客户端用于执行任务的配置。
聚合方法
第三个通常需要自定义或调优的例子是聚合。在联邦学习中,有许多不同的聚合模型参数的方法。
你在上一课中已经看到并使用了联邦平均算法。还有许多其他方法,如Q-Fed平均或FedAdam,它们对基础的联邦平均算法提供了某些改进。你在上一课中使用的Flower框架内置了许多这样的方法,它们被称为“策略”。
在Flower中配置客户端训练
让我们进入实验环节,看看服务器如何配置客户端训练。
和往常一样,我们从一些导入开始。
import flwr as fl
from flwr.common import Parameters
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
使用联邦数据集
在第1课和第2课中,我们手动划分MNIST数据集来模拟分布在多个用户设备或多个组织上的多个数据集。现在,与其在参与者之间手动划分数据,我们可以使用一个名为feddatasets的库。
feddatasets为我们提供了一个名为FederatedDataset的类。这个抽象层可以划分许多现有的数据集,如MNIST,并允许你为每个客户端生成小的训练集和测试集。
load_data函数加载并准备用于联邦学习的数据。它接收一个partition_id作为输入,指定要加载的数据集分区。这里使用的数据集是MNIST,它被划分为10个分区。然后,加载的分区以80:20的比例被分割为训练和测试子集,这是通过train_test_split函数完成的。应用了一个自定义的转换来标准化数据。最后,分别为训练和测试子集创建了数据加载器,使用的是PyTorch的DataLoader。
向客户端发送配置值
除了模型参数,你通常还想向客户端发送配置值。配置值可以用于各种目的。
假设你希望服务器控制每个客户端执行的本地训练轮数(即本地客户端在训练期间遍历本地数据集的次数)。为此,你定义一个名为fit_config的函数,它只接收一个参数——当前服务器轮次,并返回一个配置字典。
在配置字典中,你放入一个名为local_epochs的键,其整数值告诉客户端要训练多少个本地轮次。在这个例子中,你可以看到如何根据当前服务器轮次来改变这个数字。
def fit_config(server_round: int):
"""返回训练配置字典。"""
config = {
"local_epochs": 2 if server_round < 3 else 5, # 前两轮训练2个epoch,之后训练5个
}
return config
客户端可以使用这个值来动态改变其在本地数据集上训练的轮数。这意味着客户端最初将训练2个轮次,然后在后续轮次中增加本地训练轮数。
接下来,我们像往常一样初始化联邦平均策略。我们将fraction_evaluate设置为0,因为我们不打算执行任何客户端评估,并将初始模型参数传递给策略。
现在,为了让策略使用我们新的fit_config函数,我们只需在初始化FedAvg时通过参数on_fit_config_fn将其传入。
strategy = fl.server.strategy.FedAvg(
fraction_fit=1.0, # 每轮选择全部客户端进行训练
fraction_evaluate=0.0, # 不进行评估
min_fit_clients=5,
min_evaluate_clients=0,
min_available_clients=5,
on_fit_config_fn=fit_config, # 传入配置函数
initial_parameters=parameters,
)
将一个fit_config函数传递给联邦平均策略,将使该策略在每一轮都调用这个函数。返回的配置字典将被包含在发送给每个客户端的消息中。它每次都会调用该函数,使你能够每一轮都向客户端发送不同的配置值。
最后,我们定义一个ServerApp的实例。
客户端接收并使用配置
你还需要像往常一样创建一个Flower客户端类。唯一的区别是,你希望使用配置中的local_epochs值。
FlowerClient类中的fit方法负责训练。它接收两个主要参数:parameters(来自服务器的模型参数)和config(用于训练的配置参数)。
和往常一样,该方法首先使用服务器提供的参数来设置模型的参数。然后从配置字典中提取local_epochs的值,以确定本地训练的轮数。在这里,你锁定这个值,然后调用train_model函数。train_model函数被调用来使用客户端的本地训练数据(存储在self.trainloader中)训练模型。这次,我们将本地训练轮数作为一个额外的参数传递。
class FlowerClient(fl.client.NumPyClient):
def __init__(self, trainloader, valloader):
self.trainloader = trainloader
self.valloader = valloader
self.model = Net()
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
def fit(self, parameters, config):
# 设置模型参数
set_parameters(self.model, parameters)
# 从配置中获取本地训练轮数
local_epochs = config.get("local_epochs", 1)
# 使用指定的轮数进行训练
train_model(self.model, self.trainloader, epochs=local_epochs, device=self.device)
# 返回更新后的参数和其他信息
return get_parameters(self.model), len(self.trainloader), {}
然后,我们像往常一样创建客户端函数和客户端应用。
让我们运行server_app和client_app,看看会发生什么。

你可以看到Flower服务器应用启动,并执行了三轮联邦学习。它初始化了由策略提供的全局参数,然后进入第一轮联邦学习。策略从5个可用客户端中采样了5个客户端,即采样了所有可用客户端。然后,它将配置字典与模型参数一起发送给所有参与的客户端,也就是我们的全部5个客户端。你可以看到第一个客户端记录它训练了2个轮次,第二个客户端也是如此,其他所有客户端都一样。训练结束后,聚合函数接收了5个结果和0个失败。
你还可以看到有一条日志显示“没有选择客户端,跳过评估”。这是因为之前在初始化联邦平均策略时,我们将fraction_evaluate设为了0,所以在客户端侧不进行评估。
在第二轮中,我们看到的情况基本相同。我们可以看到我们所有的五个客户端都被选中了,并且可以看到每个客户端都训练了2个轮次。
现在,看第三轮,你可以看到再次从5个客户端中选择了5个。但是,你可以看到客户端突然不再只训练2个轮次,而是训练了5个轮次。这是由这些客户端从服务器接收到的配置字典引起的。因此,本地训练轮数由服务器控制,客户端根据从服务器接收到的任何配置值做出反应,并执行适当数量的本地训练轮次。
配置字典是一个相当灵活的概念。你可以在该字典中放入许多不同类型的键和值,这是一个非常适合进行实验的东西。例如,你可以用它来从服务器向客户端发送学习率,并控制每个客户端应该使用的学习率。你可以用它来控制客户端训练过程的许多不同方面。
请随意尝试和实验,并告诉我们进展如何。
总结
本节课中我们一起学习了联邦学习的调优与自定义。联邦学习引入了额外的超参数和概念,这些概念对于控制服务器端的训练过程非常重要。
在服务器端,我们可以自定义和调优诸如客户端选择、客户端配置和结果聚合等方面。在客户端,我们可以配置诸如数据预处理、本地训练以及将权重发送回服务器之前要进行的任何后处理。
Flower的策略允许你自定义我们讨论过的所有这些客户端侧和服务器侧的行为。它允许你自定义客户端选择、客户端配置、结果聚合和服务器端评估。


通过本节课的学习,你应该能够理解联邦学习系统中可调优的关键组件,并掌握在Flower框架中通过配置字典动态控制客户端行为的基本方法。
005:数据隐私 🔒


在本节课中,我们将学习数据隐私以及隐私增强技术。你将了解隐私在联邦学习中的重要性,学习如何思考隐私保护技术,并以差分隐私为例进行深入探讨。
概述
联邦学习通过防止直接访问数据,本身可被视为一种数据最小化解决方案。然而,客户端与服务器之间交换的模型更新仍可能导致隐私泄露。根据攻击模型和攻击者在联邦学习中的角色,存在多种可能的攻击方式需要考虑。
隐私攻击类型
攻击者可以是客户端、服务器或第三方。以下是三种攻击示例:
- 成员推断攻击:旨在推断特定数据样本是否参与了训练。
- 属性推断攻击:旨在推断训练数据中未见的属性。
- 重建攻击:旨在推断具体的训练数据样本。
例如,有研究论文表明,在特定设置下,恶意服务器能够重建联邦学习中特定客户端的训练数据样本。可以看到,重建的图像并非完全相同,但其质量惊人地接近原始数据。
差分隐私简介
差分隐私是一种在数据分析中增强个人隐私的突出解决方案。它通过向查询结果添加校准过的噪声来模糊个体数据,确保任何单个数据点的存在与否都不会显著影响分析结果。这保证了在不泄露敏感信息的前提下进行准确分析。
假设有两个数据集 D 和 D',它们仅相差一个数据点。差分隐私保证,任何分析 M(例如计算平均收入)对这两个数据集产生的结果 O 和 O' 将几乎相同。
在机器学习中,差分隐私为我们提供以下保证:如果我们在数据集 D 上训练模型 M1,然后添加或移除一个数据点(例如本例中A的数据)后训练第二个模型 M2,那么得到的模型 M1 和 M2 将在一定程度上无法区分。这种不可区分性的程度由我们旨在实现的隐私保护水平来量化。
联邦学习中的差分隐私
在联邦学习的背景下,差分隐私可以应用于流程的各个阶段,包括模型训练、模型更新的聚合以及客户端与服务器之间的通信。根据其应用方式,差分隐私提供不同级别的隐私保护。
本节课你将学习差分隐私的两种变体:中心化差分隐私 和 本地差分隐私。
关于差分隐私有两个重要主题:
- 裁剪:用于限制敏感度并减轻异常值的影响。此处的“敏感度”是指当数据集中添加或移除单个数据点时,输出可能改变的最大量。
- 加噪:通过添加校准过的噪声,使输出在统计上无法区分。
中心化差分隐私
在中心化差分隐私中,中央服务器负责向全局聚合的参数添加噪声。需要注意的是,这要求信任服务器。整体方法是先裁剪客户端发送的模型更新,然后向聚合后的模型添加一定量的噪声。
本地差分隐私
在本地差分隐私中,每个客户端负责在本地执行差分隐私操作,然后再将更新后的模型发送给服务器。本地差分隐私避免了对完全可信聚合器的需求。每个客户端在将更新模型发送到服务器之前,负责在本地执行裁剪和加噪。
实践环节
现在,让我们进入实验环节。
与往常一样,我们首先导入实用函数和类。我们同时导入一个服务器端差分隐私策略(称为 DifferentialPrivacyClientSideAdaptiveClipping)和一个客户端自适应裁剪模块(称为 adaptive_clipping_mod)。稍后我们将解释这些组件的作用。
与上一课类似,我们加载 MNIST 数据并使用 Flower 的数据工具将其划分为 10 部分。我们还像之前一样定义了一个 Flower 客户端。
我们定义了一个 client_fn 函数,用于初始化训练加载器、测试加载器和 Flower 客户端。在定义客户端应用时,我们使用了一个名为 flower mods 的新功能。
Mods(也称为修饰器)允许你在任务在 ClientApp 中被处理之前和之后执行操作。你可以使用内置的 mods,甚至可以定义自己的自定义 mods。这里我们使用 adaptive_clipping_mod,它在将模型更新发送回服务器之前执行自适应裁剪。
在服务器端,你首先像往常一样创建联邦平均策略。这次我们给它一个不同的名字,称为 fed_avg_without_dp。然后,我们不直接将联邦平均策略对象传递给 server_app,而是用一个名为 DifferentialPrivacyClientSideAdaptiveClipping 的包装策略来包装它。

为此,你创建 DifferentialPrivacyClientSideAdaptiveClipping 的一个实例,并将之前创建的策略对象以及两个 DP 特定参数(噪声乘数和客户端采样数量)传递给它。DifferentialPrivacyClientSideAdaptiveClipping 策略包装器本身就是一个策略,它包装其他策略并负责在服务器端应用差分隐私。这意味着它接收模型更新,将其转发给内部策略进行聚合,然后向聚合后的模型添加噪声。
最后一步是像往常一样创建 server_app。现在,你有了一个匹配的服务器应用和客户端应用,它们可以执行带有客户端自适应裁剪的中心化差分隐私。
运行实验
现在,我们有了一个匹配的服务器应用和客户端应用,它们共同可以执行带有客户端自适应裁剪的中心化差分隐私。
配置好我们的服务器端 DP 策略和客户端裁剪模块后,我们可以看到联邦训练如何与 DP 一起运行。在本地训练之后,客户端模块裁剪参数并将裁剪后的模型更新发送到服务器。内部策略然后聚合这些模型更新,DP 包装策略向聚合后的模型添加噪声。


你可以像往常一样通过 run_simulation 运行此实验。这次,你使用 10 个模拟客户端运行模拟,并在每一轮中选择其中的 6 个客户端(这是 fraction_fit=0.6 的结果)。现在的日志信息更多一些,但你可以看到在客户端,模块裁剪了参数。你可以看到一些以 adaptive_clipping_mod 开头的日志,它显示参数被某个值裁剪。在以 aggregate_fit 开头的部分日志中,你可以看到添加了具有特定标准偏差的中心化 DP 噪声。

由于裁剪和加噪,DP 通常会导致收敛速度变慢。因此,在这个实验中,我们使用了一个较小的噪声乘数(这导致隐私性较低),并运行联邦训练 50 轮(理想情况下,根据所需的隐私性和实用性,可以运行更多轮)。
总结
本节课我们一起学习了数据隐私在联邦学习中的重要性。联邦学习本身并不能保证数据隐私。隐私增强技术(通常简称为 PETs),如差分隐私和安全聚合,可以提供帮助。




差分隐私(无论是中心化还是本地化)通过裁剪梯度和添加噪声来工作。此类隐私增强通常会带来成本,例如效用降低和计算开销增加。
006:带宽分析 📊

在本节课中,我们将深入学习使用联邦学习训练模型时的带宽需求。你将理解如何在理论上分析联邦系统的带宽使用情况,以及如何在实践中使用Flower框架测量带宽消耗。
上一节我们介绍了联邦学习的基本流程,本节中我们来看看其通信开销。

带宽需求计算公式 📐
为了更好地理解联邦学习系统的带宽使用情况,我们将逐步推导一个公式,用于计算运行此类系统的大致带宽需求。
我们首先关注发送给单个客户端的模型大小,以及从该客户端接收回来的模型更新大小。在某些场景下,这两个大小并不相同。有时我们向客户端发送完整的模型参数,但客户端返回的是压缩后的梯度,因此从客户端接收的更新大小会小于我们发送出去的模型大小。
将发送给客户端的模型大小与从客户端接收的模型更新大小相加,就得到了向一个客户端发送模型并接收其更新所需的带宽需求。
接着,我们将这个数值乘以队列规模,即系统中客户端的总数。然后,我们需要乘以每轮选择的客户端比例。如果我们有100个客户端,但每轮只选择其中的20%,我们就需要乘以0.2。
最后,我们将这个数字乘以执行的轮数。前面的步骤给出了单轮联邦学习的带宽需求,再乘以总轮数即可得到总带宽需求。
如果发送出去的模型大小与从客户端接收回来的模型大小相同,我们可以简化为:模型大小 × 2。
以下是计算联邦学习带宽需求的简化公式:
总带宽 ≈ (发送模型大小 + 接收更新大小) × 队列规模 × 选择比例 × 轮数
计算示例 🔢
让我们通过一个例子来应用这个公式。在课程2中,将涉及在私有数据上进行联邦大语言模型微调。将使用的语言模型是EleutherAI的Pythia-14m模型,这是一个拥有1400万个参数的模型。该模型的大小为53 MB。
根据我们的简化公式,我们将其乘以2,得到向一个客户端发送模型并接收其更新所需的带宽。如果我们跨两个客户端训练此模型,则将该数字乘以2。因为我们只有两个客户端,所以每轮都会选择它们。
如果我们有100个客户端,在单轮联邦学习中只选择其中的50个是合理的。在这种情况下,我们将选择比例设置为0.5,而不是1.0。
最后,我们乘以联邦学习的轮数。在实验中,为了节省时间,我们只进行一轮,因此将其设置为1。这样,我们得到单轮联邦学习的总带宽使用量约为212 MB。
实验验证 🧪
现在,让我们进入实验环节,看看我们的计算是否正确。
以下是实验的关键步骤:

- 导入模块:导入必要的工具函数和类,以及一个用于在客户端跟踪参数传输大小的模块。
- 初始化模型:初始化一个拥有1400万个参数的Pythia-14m语言模型,并计算其大小(确认为53 MB)。
- 定义客户端:定义Flower客户端。与之前课程不同的是,我们跳过了实际的训练和评估部分,因为我们只想在服务器端计算带宽。
- 创建自定义策略:为了跟踪发送和接收的模型大小,我们创建一个名为
BandwidthTrackingFedAvg的自定义策略。它扩展了之前课程中使用的联邦平均策略。- 在
aggregate_fit方法中,它为每个客户端的结果计算接收到的模型更新大小(以MB为单位)并记录。 - 在
configure_fit方法中,它计算即将发送给客户端的模型大小(以MB为单位)并记录。
- 在
- 配置服务器:创建服务器应用,将策略设置为自定义策略,并将轮数设置为1(因为在此设置中,带宽需求在连续轮次中不会改变)。
- 运行模拟并查看结果:启动模拟后,在日志中可以看到:
- 要发送的模型大小为53 MB。
- 服务器接收了两个大小约为53 MB的模型。
- 最后,通过汇总记录的所有带宽大小,得到总带宽使用量为212 MB,这与我们使用公式计算的结果完全一致。
带宽优化策略 ⚙️

在实验中我们看到,即使只进行一轮联邦学习,也可能快速消耗大量带宽。有许多方法可以减少联邦学习中的带宽使用。
优化带宽主要分为两类:减少单个更新的大小和减少通信频率。
以下是减少更新大小的几种方法:
- 稀疏化:例如使用Top-K稀疏化。如果要通信的梯度低于某个阈值,则将其作为零进行通信(实际上可以跳过),从而节省通信成本。这在训练后期尤其有效,因为梯度中更多元素的幅度会变小。
- 量化:量化有多种形式,它们通过减少表示标量所需的位数来降低客户端与服务器之间交换的更新大小。
你也可以利用预训练模型。在许多场景下,可以找到一个对特定应用有用的预训练模型,然后联邦学习可以在此基础上继续训练。在这种情况下,我们可能不需要训练每一层,而只需通信被联邦训练修改的层。
另一种方法是在本地训练更长时间,然后再与服务器交换更新。例如,不是只训练一个本地周期,而是在发送更新后的模型回服务器之前训练五个周期。但需要注意,这也可能阻碍收敛。如果本地模型训练过多周期,它们可能会越来越发散,导致聚合后的模型变差而非变好。
总结 📝
本节课中我们一起学习了联邦学习的带宽分析。
- 我们可以通过将发送的模型大小与接收的更新大小相加,再乘以队列规模、每轮选择的客户端比例以及总轮数,来计算带宽需求。
- 在实际实现中,我们可以使用Flower中的客户端模块和服务器端策略来测量服务器端和客户端的带宽需求。
- 我们可以通过应用稀疏化或量化等技术、使用预训练模型(不通信所有层)或者在交换模型更新之前进行更多的本地训练,来优化带宽利用率。

理解并管理带宽是构建高效、可扩展的联邦学习系统的关键一步。
007:课程总结与展望 🎯
在本节课中,我们将对联邦学习课程的核心内容进行总结,并展望后续的学习方向与应用前景。
课程概述 📚
在本系列课程中,我们学习了联邦学习及其在解锁大量当前未被利用的分布式训练数据方面所扮演的重要角色。
核心学习内容回顾 🔄
上一节我们介绍了联邦学习的应用场景,本节中我们来回顾整个课程的核心要点。
以下是本课程涵盖的主要学习模块:
- 使用 Flower 联邦学习框架,构建了不同版本的联邦训练流程。
- 学习了如何调整联邦系统的不同方面。
- 探讨了如何考量数据隐私问题。
- 掌握了如何使用 Flower 计算和测量联邦系统的带宽使用情况。
后续学习方向 🚀
那么,接下来该做什么?这是两门系列课程中的第一门。下一门课程将介绍如何在私有数据上进行联邦大语言模型(LLM)的微调。
探索与实践鼓励 💡
我也鼓励你开始自己的探索。我们非常乐意在 Flower AI 或 X 平台的 FLwr labs 上听到你的声音。加入 Flower 社区 Slack 频道,那里有成千上万志同道合的人工智能研究者和开发者正在交流想法。
致谢与展望 🙏
感谢令人惊叹的 Flower 社区。看到如此多的项目不断突破边界,这真是一种激励。我期待着看到你们自己构建的作品。

课程总结 ✨

本节课中,我们一起学习了联邦学习的基础概念、使用Flower框架的实践方法、系统调优、隐私考量以及资源评估。联邦学习为在保护数据隐私的前提下利用分散数据提供了强大的解决方案,希望你已准备好将其应用于自己的项目中。

浙公网安备 33010602011771号