天天滚动:ICLR 2023 | DIFFormer: 扩散过程启发的Transformer
2023-04-30 09:58:38
【资料图】
机器之心专栏
机器之心编辑部本⽂介绍⼀项近期的研究⼯作,试图建⽴能量约束扩散微分⽅程与神经⽹络架构的联系,从而原创性的提出了物理启发下的 Transformer,称作 DIFFormer。作为⼀种通⽤的可以灵活⾼效的学习样本间隐含依赖关系的编码器架构,DIFFormer 在各类任务上都展现了强大潜⼒。这项工作已被 ICLR 2023 接收,并在⾸轮评审就收到了四位审稿⼈给出的 10/8/8/6 评分(最终均分排名位于前 0.5%)。论⽂地址:https://arxiv.org/pdf/2301.09474.pdf 项⽬地址:https://github.com/qitianwu/DIFFormer 简介如何得到有效的样本表征是机器学习领域的⼀⼤核⼼基础问题,也是深度学习范式在各类下游任务 能发挥作用的重要前提。传统的表征学习⽅法通常假设每个输⼊样本是独⽴的,即分别将每个样本输⼊进 encoder ⽹络得到其在隐空间中的表征,每个样本的前向计算过程互不干扰。然⽽这⼀假设通常与现实物理世界中数据的⽣成过程是违背的:由于显式的物理连接或隐含的交互关系,每个观测样本之间可能存在相互的依赖。 这⼀观察也启发了我们去重新思考⽤于表征计算的 encoder ⽹络设计:是否能设计⼀种新型的 encoder ⽹络能够在前向计算中显式的利⽤样本间的依赖关系(尽管 这些依赖关系是未被观察到的)。在这个⼯作中,我们从两个物理学原理出发,将神经⽹络计算样本表征的前向过程看作给定初始状态的扩散过程,且随着时间的推移(层数加深)系统的整体能量不断下降(见下图)。 DIFFormer 模型主要思想的示意图:将模型计算样本表征的前向过程看作⼀个扩散过程,随着时间的推移,节点之间存在信号传递,且任意节点对之间信号传递的速率会随着时间适应性的变化,使得系统整体的能量最⼩化。通过扩散过程和能量约束,最终的样本表征能够吸收个体和全局的信息,更有助于下游任务。通过试图建⽴扩散微分⽅程与神经⽹络架构的联系,我们阐释了能量约束扩散过程与各类信息传递网络(如 MLP/GNN/Transformers)的联系,并为新的信息传递设 计提供了⼀种理论参考。基于此,我们提出了⼀种新型的可扩展 Transformer 模型,称为 DIFFormer(diffusionbased Transformers)。它可以作为⼀种通⽤的 encoder,在前向计算中利⽤样本间隐含的依赖关系。⼤量实验表明在⼩ / ⼤图节点分类、图⽚ / ⽂本分类、时空预测等多个领域的实验任务上 DIFFormer 都展现了强⼤的应⽤潜⼒。在计算效率上,DIFFormer 只需要 3GB 显存就可以实现⼗万级样本间全联接的信息传递。 动机与背景我们⾸先回顾⼀个经典的热⼒学中的热传导过程:假设系统中有 个节点,每个节点有初始的温度,两两节点之间都存在信号流动,随着时间的推移节点的温度会不断更新。上述物理过程事实上可以类⽐的看作深度神经网络计算样本表征(embedding)的前向过程。 将神经⽹络的前向计算过程看作⼀个扩散过程:每个样本视为流形上的固定位置节点,样本的表征为节点的信号,表征的更新视作节点信号的改变,样本间的信息传递看作节点之间的信号流动具体的,考虑包含 个样本的数据集,用 表示样本 i 的输入特征, 表示样本 i 的表征向量。⼀个 L 层的神经网络模型会把每个输⼊样本映射到⼀系列隐空间中的表征向量: 这⾥我们可以把每个样本看作⼀个离散空间中的节点,样本表征看作节点的信号。当模型结构考虑样本交互时(如信息传递),它可以被看作节点之间的信号流动,随着模型层数加深(即时间的推移),样本表征会不断被更新。 扩散过程的描述⼀个经典的扩散过程可以由⼀个热传导⽅程(带初始条件的偏微分⽅程)来描述 这⾥的 , 和 分别表示梯度(gradient)算⼦、散度 (divergence) 算⼦和扩散率(diffusivity)。对于由 N 个节点组成的离散化空间,以上三个概念的具体定义可以如下表示: 在离散空间中,梯度算⼦可以看作两两节点的信号差异,散度算子可以看作单个节点流出信号的总和,⽽扩散率(diffusivity)是⼀种对任意两两节点间信号流动速率的度量由此我们可以写出描述 N 个节点每时每刻状态更新的扩散微分⽅程,它描述了每个状态下系统中每个节点信号的变化等于流向其他节点的信号总和: 这⾥的扩散率 定义了在当前时刻任意两两节点 之间的影响,即信号从节点 流向 的速率的⼀种度量。 由扩散方程导出的信息传递我们进⼀步使⽤数值有限差分(具体的这⾥使⽤显式欧拉法)将上述的微分⽅程展开成迭代更新的形式,引⼊⼀个步⻓ 对连续时间进⾏离散化(再经过⽅程左右重新整理): 这⾥的第⼀项系数可以被视作⼀个常数(如果假设 是经过沿⾏归⼀化的),于是上式就可以视为⼀个对其他样本表征的信息聚合(第⼆项)再加上⼀个对上⼀层⾃身表征的 residual 连接(第⼀项)。这⾥的扩散率 是⼀个 的矩阵,我们可以对其进⾏不同的假设,就可以得到不同模型的层间更新公式: 如果 是⼀个 的单位矩阵:(1)式中每个样本的表征计算只取决于⾃⼰(与其他样本独⽴),此时给出的是 Multi-Layer Perceptron (MLP) 的更新公式,即每个样本被单独输⼊进 encoder 计算表征; 如果 在固定位置存在⾮零值(如输⼊图中存在连边的位置):(1)式中每个样本的表征更新会依赖于图中相邻的其他节点,此时给出的是 Graph Neural Networks (GNN) 的更新公式,其中 是传播矩阵(propagation matrix),例如图卷积⽹络(GCN)模型采⽤归⼀化后的邻接矩阵 ; 如果 在所有位置都允许有⾮零值,且每层的 都可以发⽣变化:(1)式中每个样本的表征更新会依赖于其他所有节点,且每次更新两两节点间的影响也会适应性的变化,此时 (1) 式给出的是 Transformer 结构的更新公式, 表示第 层的 attention 矩阵。 下图概述了这三种信息传递模式: 我们研究最后⼀种信息传递⽅式,每层更新的样本表征会利⽤上⼀层所有其他样本的表征,在理论上模型的表达能⼒是最强的。但由此产⽣的⼀个问题是:要如何才能确定合适的每层任意两两节点之间的 diffusivity,使得模型能够产⽣理想的样本表征?刻画⼀致性的能量函数我们这⾥引⼊⼀个能量函数,来刻画每时每刻由系统中所有节点表征所定义的内在⼀致性,通过能量的最⼩化来引导扩散过程中节点信号的演 变⽅向。具体的,对于样本表征 ,其对应的能量定义为: 这⾥的第⼀项约束了每个节点对⾃身当前状态的局部⼀致性,第⼆项了约束了与系统中其他节点的全局⼀致性。其中 是⼀个单调递增的凹函数(当 与 差别较⼤时, 会返回⼀个适中的能量值,即减⼩对差异较⼤的节点对 的“惩罚”,这有助于提升样本表征的 diversity)。理想情况下,当系统的整体能量达到最⼩化,我们可以认为系统中的每⼀个个体都与整体取得了平衡,样本的表征同时吸收了局部和全局的信息。 能量约束的扩散过程基于此,我们考虑⼀种带能量约束的扩散过程,每⼀步的扩散率 被定义为⼀个待优化的隐变量,我们希望它给出的每⼀步的节点表征都能够使得系统整体的能量下降。带能量约束的扩散过程可以被形式化的描述为: 虽然直接求解 ⾮常复杂(因为他耦合了每⼀步能量下降的约束),不过本⽂通过理论分析建⽴了扩散⽅程数值迭代与能量优化梯度更新的等价性,从⽽得到了每⼀步扩散率的最优闭式解。 定理对于任意的由 (2) 式所定义的能量函数,存在步⻓ 和相应的扩散率估计 使得由 (1) 式定义的扩散⽅程数值迭代保证每⼀步的能量下降,即 。 基于这⼀理论结果,我们进⽽提出了扩散过程诱导下的 Transformer 结构,即 DIFFormer,它的每⼀层更新公式表示为: 这⾥的 表示衡量 和 相似性的函数,在具体设计时具有很⼤的灵活性。下⾯我们提出两种具体设计,分别称相应的模型结构为 DIFFormer-s 和 DIFFormer-a。 DIFFormer-s :采⽤简单的 dot-product 来衡量相似性,作为 attention function(这⾥使⽤ L2 normalization 将输⼊向量限制在 [-1,1] 之间从⽽保证得到的注意⼒权重⾮负): DIFFormer-a :在计算相似度时引⼊⾮线性,从⽽提升模型学习复杂结构的表达能⼒: 当我们考虑每层两两节点之间的全局 attention,⼀个潜在的问题是 all-pair attention 带来的 平⽅复杂度。庆幸的是,这⾥ DIFFormer-s 的 attention 定义可以保证每⼀层更新 个样本表征的计算复杂度在 之内,这⾮常有利于提升模型的时空效率(特别是空间效率,当需要扩展到包含⼤量样本的数据集时)。 为什么能实现复杂度呢? 我们可以把 代⼊更新单个样本的聚合公式,然后通过矩阵乘法结合律交换矩阵运算的顺序(这⾥假设 ): 在上式左边的式⼦中,计算⼀次需要 复杂度,⽽⼜因为这是对单个样本的更新公式,因此更新 个不同的样本需要的复杂度是。但在右边的式⼦中,分⼦和分⺟的两个求和项对于所有样本是共享的,也就是说在实际计算中只需要 算⼀次,⽽后对每个样本的更新只需要 ,因此更新 个样本的总复杂度是 。不过对于 DIFFormer-a 的 attention 设计,则⽆法保证 的计算复杂度,因为⾮线性的引⼊导致了⽆法交换矩阵运输的次序。下图总结了两个模型在具体实现(采⽤矩阵乘法更新⼀层所有样本的表征)中的运算过程。 两种模型 DIFFormer-s 和 DIFFormer-a 每层更新的运算过程(矩阵形式),红⾊标注的矩阵乘法操作是计算瓶颈。DIFFormer-s 的优势在于可以实现对样本数量 N 的线性复杂度,有利于模型扩展到⼤规模数据集模型扩展更进⼀步的,我们可以引⼊更多设计来提升模型的适⽤性和灵活度。上述的模型主要考虑了样本间的 all-pair attention。对于输⼊数据本身就含有样本间图结构的情况,我们可以加⼊现有图神经⽹络(GNN)中常⽤的传播矩阵(propagation matrix)来融合已知的图结构信息,从⽽定义每层的样本表征更新如下 ⽐如如果采⽤图卷积⽹络(GCN)中的传播矩阵,则这⾥ , 表示输⼊图, 表示其对应的(对⻆)度矩阵。 类似其他 Transformer ⼀样,在每层更新中我们可以加⼊ residual link,layer normalization,以及⾮线性激活。下图展示了 DIFFormer 的单层更新过程。 DIFFormer 的全局输⼊包含样本输⼊特征 X 以及可能存在的图结构 A(可以省略),通过堆叠 DIFFormer layer 更新计算样本表征。在每层更新时,需要计算⼀个全局 attention(具体的可以使⽤ DIFFormer-s 和 DIFFormer-a 两种实现),如果考虑输⼊图结构则加⼊ GCN Conv另⼀个值得探讨的问题,是如何处理⼤规模数据集(尤其是包含⼤量样本的数据集,此时考虑全局 all-pair attention ⾮常耗费资源)。在这种情况下我们默认使⽤线性复杂度的 DIFFormer-s 的架构,并且可以在每个训练 epoch 对数据集进⾏ random mini-batch 划分。由于线性复杂度,我们可以使⽤较⼤的 batch size 也能使得模型在单卡上进⾏训练(详⻅实验部分)。 对于包含⼤量样本的数据集,我们可以对样本进⾏随机 minibatch 划分,每次只输⼊⼀个 batch 的样本。当输⼊包含图结构时,我们可以只提取 batch 内部样本所组成的⼦图输⼊进⽹络。由于 DIFFormer-s 只需要对 batch size 的线性复杂度,在实际中就可以使⽤较⼤的 batch size,保证充⾜的全局信息实验结果为了验证 DIFFormer 的有效性和在不同场景下的适⽤性,我们考虑了多个实验场景,包括不同规模图上的节点分类、半监督图⽚ / ⽂本分类和时空预测任务。 图节点分类实验此时输⼊数据是⼀张图,图中的每个节点是⼀个样本(包含特征和标签),⽬标是利⽤节点特征和图结构来预测节点的标签。我们⾸先考虑⼩规模图 的实验,此时可以将⼀整图输⼊ DIFFormer。相⽐于同类模型例如 GNN,DIFFormer 的优势在于可以不受限于输⼊图,学习未被观测到的连边关系,从⽽更好的捕捉⻓距离依赖和潜在关系。下图展示了与 SOTA ⽅法的对⽐结果。 进⼀步的我们考虑在⼤规模图上的实验。此时由于图的规模过⼤,⽆法将⼀整图直接输⼊模型(否则将造成 GPU 过载),我们使⽤ mini-batch 训练。具体的,在每个 epoch,随机的将所有节点分为相同⼤⼩的 mini-batch。每次只将⼀个 mini-batch 的节点输⼊进⽹络;⽽对于输⼊图,只使⽤包含在这个 mini-batch 内部的节点所组成的⼦图输⼊进⽹络;每次迭代过程中,DIFFormer 也只会在 mini-batch 内部的节点之间学习 all-pair attention。这样做就能⼤⼤减⼩空间消耗。⼜因为 DIFFormer-s 的计算复杂度关于 batch size 是线性的,这就允许我们使⽤很⼤的 batch size 进⾏训练。下图显示了在 ogbn-proteins 和 pokec 两个⼤图数据集上的测试性能,其中对于 proteins/pokec 我们分别使⽤了 10K/100K 的 batch size。此外,下图的表格也展示了 batch size 对模型性能的影响,可以看到,当使⽤较⼤ batch size 时,模型性能是⾮常稳定的。 图⽚ / ⽂本分类实验第⼆个场景我们考虑⼀般的分类问题,输⼊是⼀些独⽴的样本(如图⽚、⽂本),样本间没有已观测到的依赖关系。此时尽管没有输⼊图结构, DIFFormer 仍然可以学习隐含在数据中的样本依赖关系。对于对⽐⽅法 GCN/GAT,由于依赖于输⼊图,我们这⾥使⽤ K 近邻⼈⼯构造⼀个样本间的图结构。 时空预测进⼀步的,我们考虑时空预测任务,此时模型需要根据历史的观测图⽚段(包含上⼀时刻节点标签和图结构)来预测下⼀时刻的节点标签。这⾥我们横向对⽐ 了 DIFFormer-s/DIFFormer-a 在使⽤输⼊图和不使⽤输⼊图(w/o g)时的性能,发现在不少情况下不使⽤输⼊图模型反⽽能给出的较⾼预测精度。这也说明了在这类任务中,给定的观测图结构可能是不可靠的,⽽ DIFFormer 则可以通过从数据中学习依赖关系得到更有⽤的结构信息。 扩散过程下的统⼀视⻆从能量约束的扩散过程出发,我们也可以将其他信息传递模型如 MLP/GCN/GAT 看作 DIFFormer 的特殊形式,从⽽给出统⼀的形式化定义。下图概括了⼏种⽅法对应的能量函数和扩散率。相⽐之下,从扩散过程来看, DIFFormer 会考虑任意两两节点之间的信号流动且流动的速率会随着时间适应性的变化,⽽ GNN 则是将信号流动 限制在⼀部分节点对之间。从能量约束来看,DIFFormer 会同时考虑局部(与⾃身状态)和全局(与其他节点)的⼀致性约束,⽽ MLP/GNN 则是分别侧重于⼆者之⼀, 且 GNN 通常只考虑输⼊图中相邻的节点对约束。 总结与讨论在这个⼯作中,我们讨论了如何从扩散⽅程出发得到 MLP/GNN/Transformer 的模型更新公式,⽽后提出了⼀个能量约束下的扩散过程,并通过理论分析得到了最优 扩散率的闭式解。基于理论结果,我们提出了 DIFFormer。总的来说,DIFFormer 主要具有以下两点优势: 从设计思想上看:模型结构从能量下降扩散过程的⻆度导出,相⽐于直接的启发式设计更加具有理论依据; 从模型实现上看:在保留了学习每层所有节点全局 all-pair attention 的表达能⼒的同时,DIFFormer-s 只需要复杂度来更新 个节点的表征,同时兼容 mini-batch training,可以有效扩展到⼤规模数据集。 DIFFormer 作为⼀个通⽤的 encoder,可以被主要应⽤于以下⼏种场景: 建模含有观测结构的数据,得到节点表征(简⾔之就是使⽤ GNN 的场景):输⼊是⼀张图包含了互连的节点,需要计算图中节点的表征。这是⼀个相对已被⼴泛研究的领域,DIFFormer 的优势在于可以挖掘未被观测的隐式结构(如图中的缺失边、⻓距离依赖等),以及在低标签率的情况下提升精度。 建模不含观测结构但样本间存在隐式依赖的数据(如⼀般的分类 / 回归问题):数据集包含⼀系列独⽴样本 ,样本间的依赖关系未知。此时 DIFFormer 可⽤于学习样本间的隐式依赖关系,利⽤全局信息来计算每个样本的表征。这是⼀个较少被研究的领域,传统⽅法的主要 bottleneck 是在⼩数据集上容易过拟合(由于考虑了样本依赖模型过于复杂),⼤数据集上⼜⽆法有效扩展(学习任意两两样本的关系带来了平⽅复杂度)。DIFFormer 的优势在于简单的模型结构有效避免了过拟合问题,⽽且保证了相对于样本数量的复杂度可以有效扩展到⼤规模数据集。
作为⼀般的即插即⽤式 encoder,解决各式各样的下游任务(如⽣成 / 预测 / 决策问题)。此时 DIFFormer 可以直接⽤于⼤框架下的某个部件,得到输⼊数据的隐空间表征,⽤于下游任务。相⽐于其他 encoder (如 MLP/GNN/Transformer),DIFFormer 的优势在于可以⾼效的计算全局 attention,同时具有⼀定的理论基础(能量下降扩散过程的观点)。
最后欢迎感兴趣的朋友们阅读论⽂和访问我们的 GitHub,共同学习进步~ 参考⽂献[1] Qitian Wu et al., DIFFormer: Scalable (Graph) Transformers Induced by Energy Constrained Diffusion, ICLR 2023.[2] Qitian Wu et al., NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification, NeurIPS 2022.[3] Chenxiao Yang et al., Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks, NeurIPS 2022©THE END
转载请联系本公众号获得授权
投稿或寻求报道:content@jiqizhixin.com
标签: