FedJAX:使用 JAX 进行联邦学习模拟

1727688353967.jpg

联邦学习是一种机器学习设置,其中许多客户端(即移动设备或整个组织,取决于手头的任务)在中央服务器的协调下协作训练模型,同时保持训练数据的分散性。例如,联邦学习可以根据永远不会离开移动设备的用户数据 来训练虚拟键盘语言模型。

联邦学习算法通过首先在服务器上初始化模型并完成每轮训练的三个关键步骤来实现这一点:

服务器将模型发送给一组采样的客户端。

这些采样客户端利用本地数据来训练模型。

训练完成后,客户端将更新后的模型发送到服务器,服务器将它们聚合在一起。

1727688353967.jpg

具有四个客户端的联邦学习算法示例。

由于对隐私和安全的关注度不断提高,联邦学习已成为一个特别活跃的研究领域。对于这样一个快速发展的领域 来说,能够轻松地将想法转化为代码、快速迭代以及比较和重现现有基线非常重要。

鉴于此,我们很高兴推出FedJAX,这是一个基于JAX的联邦学习模拟开源库,强调研究的易用性。凭借其用于实现联邦算法的简单构建块、预打包的数据集、模型和算法以及快速的模拟速度,FedJAX 旨在让研究人员更快、更轻松地开发和评估联邦算法。在这篇文章中,我们讨论了 FedJAX 的库结构和内容。我们证明,在 TPU 上,FedJAX 可用于在几分钟内使用联邦平均法在EMNIST数据集上训练模型,并在大约一小时内使用标准超参数在Stack Overflow数据集上训练模型。

图书馆结构

考虑到易用性,FedJAX 仅引入了一些新概念。使用 FedJAX 编写的代码类似于学术论文中用于描述新算法的伪代码,因此很容易上手。此外,虽然 FedJAX 为联邦学习提供了构建块,但用户可以使用NumPy和 JAX将其替换为最基本的实现,同时仍能保持整体训练速度相当快。

包含的数据集和模型

在当前的联邦学习研究领域中,有各种常用的数据集和模型,例如图像识别、语言建模等。越来越多的数据集和模型可以在 FedJAX 中直接使用,因此预处理的数据集和模型不必从头开始编写。这不仅鼓励不同联邦算法之间的有效比较,而且还加速了新算法的开发。

目前,FedJAX 附带以下数据集和示例模型:

EMNIST-62,字符识别任务

莎士比亚,下一个角色预测任务

Stack Overflow,下一个单词预测任务

除了这些标准设置外,FedJAX 还提供了用于创建新数据集和模型的工具,这些数据集和模型可与库的其余部分一起使用。最后,FedJAX 附带了联邦平均和其他联邦算法的标准实现,用于在分散的示例上训练共享模型,例如自适应联邦优化器、不可知联邦平均和Mime,以便更轻松地与现有算法进行比较和评估。

绩效评估

我们对两个任务上的标准 FedJAX自适应联合平均实现进行了基准测试:联合 EMNIST-62 数据集的图像识别任务和Stack Overflow 数据集的下一个单词预测任务。联合 EMNIST-62 是一个较小的数据集,包含 3400 名用户及其书写样本,这些样本是 62 个字符(字母数字)之一,而 Stack Overflow 数据集要大得多,包含来自 Stack Overflow 论坛的数百万个问题和数十万用户的答案。

我们测量了各种机器学习专用硬件的性能。对于联邦 EMNIST-62,我们在 GPU(NVIDIA V100)和 TPU(Google TPU v2 上的 1 个 TensorCore)加速器上训练了一个模型,每轮 10 个客户端,共 1500 轮。

对于 Stack Overflow,我们使用 jax.jit 在 GPU(NVIDIA V100)上训练了一个模型,每轮有 50 个客户端,共训练了 1500 轮,仅使用 jax.jit 在 TPU(Google TPU v2 上的 1 个 TensorCore)上训练,使用 jax.pmap 在多核 TPU(Google TPU v2 上的 8 个 TensorCores)上训练。在下图中,我们记录了平均训练轮次完成时间、对测试数据进行全面评估所需的时间以及包括训练和全面评估在内的整体执行时间。

1727688396108.jpg

联邦 EMNIST-62 的基准测试结果。

1727688414960.jpg

Stack Overflow 的基准测试结果。

使用标准超参数和 TPU,联邦 EMNIST-62 的完整实验可以在几分钟内完成,而 Stack Overflow 则大约需要一个小时才能完成。

1727688435544.jpg

随着每轮客户端数量的增加,Stack Overflow 平均训练轮次持续时间。

我们还评估了随着每轮客户端数量的增加,Stack Overflow 平均训练轮次持续时间。通过比较图中 TPU(8 核)和 TPU(1 核)之间的平均训练轮次持续时间,可以明显看出,如果每轮参与的客户端数量很大,则使用多个 TPU 核心可以显著提高运行时间(这对于差分隐私学习等应用很有用)。

结论和未来工作

在这篇文章中,我们介绍了 FedJAX,这是一个快速且易于使用的研究用联邦学习模拟库。我们希望 FedJAX 能够促进人们对联邦学习的更多研究和兴趣。展望未来,我们计划不断扩充现有的算法、聚合机制、数据集和模型集合。

欢迎查看我们的一些教程笔记本,或亲自试用 FedJAX !有关该库及其与Tensorflow Federated等平台的关系的更多信息,请参阅我们的论文、README或常见问题解答。

致谢

我们要感谢 Ke Wu 和 Sai Praneeth Kamireddy 为该库做出的贡献以及开发过程中的各种讨论。

我们还要感谢 Ehsan Amid、Theresa Breiner、Mingqing Chen、Fabio Costa、Roy Frostig、Zachary Garrett、Alex Ingerman、Satyen Kale、Rajiv Mathews、Lara Mcconnaughey、Brendan McMahan、Mehryar Mohri、Krzysztof Ostrowski、Max Rabinovich、Michael Riley、Vlad Schogol、Jane Shapiro、Gary Sivek、Luciana Toledo-Lopez 和 Michael Wunder 提供的有益评论和贡献。

版权声明

本文仅代表作者观点,不代表本站立场。
本文系作者授权发表,未经许可,不得转载。

评论