AdaTape:具有自适应计算和动态读写的基础模型

BQ%$ERHZ760FREHCSE]AM%F.png自适应计算是指机器学习系统根据环境变化调整其行为的能力。虽然传统神经网络具有固定的功能和计算能力,即它们花费相同数量的 FLOP 来处理不同的输入,但具有自适应和动态计算的模型会根据输入的复杂性调整用于处理每个输入的计算预算。

神经网络中的自适应计算之所以具有吸引力,主要有两个原因。首先,引入自适应性的机制提供了一种归纳偏差,这在解决一些具有挑战性的任务中可以发挥关键作用。例如,在解决需要对不同深度层次进行建模的算术问题时,为不同的输入启用不同数量的计算步骤可能至关重要。其次,它使从业者能够通过动态计算提供的更大灵活性来调整推理成本,因为这些模型可以进行调整以花费更多的 FLOP 来处理新输入。

通过对各种输入使用不同的函数或计算预算,可以使神经网络具有自适应性。深度神经网络可以被认为是一种基于输入及其参数输出结果的函数。为了实现自适应函数类型,需要根据输入有选择地激活一组参数,这个过程称为条件计算。在混合专家研究中,已经探索了基于函数类型的自适应性,其中通过路由确定每个输入样本的稀疏激活参数。

自适应计算的另一个研究领域涉及动态计算预算。与标准神经网络(例如T5、GPT-3、PaLM和ViT)不同,这些网络的计算预算对于不同的样本是固定的,而最近的研究表明,自适应计算预算可以提高 Transformer 不足的任务的性能。其中许多工作通过使用动态深度来分配计算预算来实现自适应性。例如,提出了自适应计算时间(ACT) 算法来为循环神经网络提供自适应计算预算。Universal Transformer通过使计算预算取决于每个输入示例或 token 使用的 Transformer 层数,将 ACT 算法扩展到 Transformer。最近的研究,如PonderNet,在改进动态停机机制的同时采用了类似的方法。

在论文“具有弹性输入序列的自适应计算”中,我们介绍了一种利用自适应计算的新模型,称为AdaTape。该模型是一种基于 Transformer 的架构,它使用一组动态标记来创建弹性输入序列,与以前的研究相比,它在自适应性方面提供了独特的视角。AdaTape 使用自适应磁带读取机制来确定根据输入的复杂性添加到每个输入的不同数量的磁带标记。AdaTape 实现起来非常简单,提供了一个有效的旋钮来在需要时提高准确性,但与其他自适应基线相比也更高效,因为它直接将自适应性注入输入序列而不是模型深度。最后,Adatape 在标准任务(如图像分类)以及算法任务上提供了更好的性能,同时保持了有利的质量和成本权衡。

具有弹性输入序列的自适应计算变压器

AdaTape 同时使用自适应函数类型和动态计算预算。具体来说,对于标记化后的一批输入序列(例如,视觉转换器中图像中不重叠块的线性投影),AdaTape 使用表示每个输入的向量来动态选择可变大小的磁带标记序列。

AdaTape 使用一组令牌(称为“磁带库”)来存储通过自适应磁带读取机制与模型交互的所有候选磁带令牌。我们探索了两种创建磁带库的不同方法:输入驱动库和可学习库。

输入驱动库 的一般思想是从输入中提取一组标记,同时采用与原始模型标记器不同的方法将原始输入映射到输入标记序列。这可以动态、按需访问使用不同视角(例如不同的图像分辨率或不同的抽象级别)获得的输入信息。

在某些情况下,无法在不同的抽象级别进行标记化,因此输入驱动的磁带库是不可行的,例如当难以进一步拆分图转换器中的每个节点时。为了解决这个问题,AdaTape 提供了一种更通用的方法来生成磁带库,即使用一组可训练向量作为磁带标记。这种方法被称为可学习库,可以看作是一个嵌入层,模型可以根据输入示例的复杂性动态检索标记。可学习库使 AdaTape 能够生成更灵活的磁带库,使其能够根据每个输入示例的复杂性动态调整其计算预算,例如,更复杂的示例从库中检索更多标记,这让模型不仅可以使用库中存储的知识,还可以花费更多的 FLOP 来处理它,因为输入现在更大了。

最后,选定的磁带标记将附加到原始输入并馈送到以下转换器层。对于每个转换器层,所有输入和磁带标记都使用相同的多头注意力。但是,使用两个不同的前馈网络 (FFN):一个用于原始输入中的所有标记,另一个用于所有磁带标记。我们观察到,对输入和磁带标记使用单独的前馈网络可以稍微提高质量。

AdaTape 概述。对于不同的样本,我们从磁带库中挑选不同数量的标记。磁带库可以由输入驱动,例如,通过提取一些额外的细粒度信息,也可以是一组可训练向量。自适应磁带读取用于递归地为不同的输入选择具有可变长度的不同磁带标记序列。然后只需将这些标记附加到输入并馈送到变压器编码器。

AdaTape 提供了有用的归纳偏差

我们对 AdaTape 进行了奇偶校验评估,这对于标准 Transformer 来说是一项非常具有挑战性的任务,目的是研究 AdaTape 中归纳偏差的影响。对于奇偶校验任务,给定一个 1、0 和 -1 序列,模型必须预测序列中 1 的数量是偶数还是奇数。奇偶校验是最简单的非计数器自由或周期性正则语言,但令人惊讶的是,标准 Transformer 无法解决该任务。

对奇偶校验任务的评估。标准 Transformer 和 Universal Transformer 都无法执行此任务,均表现出随机猜测基线水平的性能。

尽管在简短的序列上进行了评估,但标准 Transformer 和 Universal Transformer 都无法执行奇偶校验任务,因为它们无法在模型中维护计数器。然而,AdaTape 的表现优于所有基线,因为它在其输入选择机制中整合了轻量级循环,提供了一种归纳偏差,可以隐式维护计数器,而这在标准 Transformer 中是不可能的。

图像分类评估

我们还在图像分类任务上对 AdaTape 进行了评估。为此,我们从头开始在ImageNet-1K上训练 AdaTape。下图显示了 AdaTape 和基线方法(包括A-ViT和 Universal Transformer ViT(UViT 和 U2T))的准确率与它们的速度(以每秒每个代码处理的图像数量来衡量)。在质量和成本权衡方面,AdaTape 的表现远远优于替代自适应变压器基线。在效率方面,较大的 AdaTape 模型(就参数数量而言)比较小的基线更快。这些结果与以前的研究结果一致,表明自适应模型深度架构不太适合许多加速器,例如 TPU。

我们通过在 ImageNet 上从头开始训练来评估 AdaTape。对于A-ViT,我们不仅报告了其论文中的结果,还通过从头开始训练重新实现了 A-ViT,即 A-ViT(我们的)。

AdaTape 行为研究

除了在奇偶校验任务和 ImageNet-1K 上的表现外,我们还在JFT-300M验证集上评估了 AdaTape 使用输入驱动库的令牌选择行为。为了更好地理解模型的行为,我们将输入驱动库上的令牌选择结果可视化为热图,其中颜色越浅表示该位置被选择的频率越高。热图显示 AdaTape 更频繁地选择中心补丁。这与我们先前的知识一致,因为中心补丁通常更具信息量——尤其是在包含自然图像的数据集中,其中主要对象位于图像中间。这一结果凸显了 AdaTape 的智能性,因为它可以有效地识别和优先考虑更具信息量的补丁,以提高其性能。

我们可视化了 AdaTape-B/32(左)和 AdaTape-B/16(右)的磁带标记选择热图。颜色越热/越浅,表示此位置的补丁被选择的频率越高。

结论

AdaTape 的特点是自适应磁带读取机制生成的弹性序列长度。这还引入了一种新的归纳偏差,使 AdaTape 有潜力解决标准 Transformer 和现有自适应 Transformer 都面临的挑战性任务。通过对图像识别基准进行全面的实验,我们证明了在计算保持不变的情况下,AdaTape 的表现优于标准 Transformer 和自适应架构 Transformer。

致谢

本文作者之一 Mostafa Dehghani 目前就职于 Google DeepMind。


版权声明

本文仅代表作者观点,不代表本站立场。
本文系作者授权发表,未经许可,不得转载。

评论