WSDM2023 | 学习蒸馏图神经网络

admin 2023年3月20日22:21:07评论28 views字数 4564阅读15分12秒阅读模式

WSDM2023 | 学习蒸馏图神经网络


题目:Learning to Distill Graph Neural Networks

会议:WSDM 2023

图神经网络(GNNs)能够有效地获取图的拓扑和属性信息,在许多领域得到了广泛的研究。近年来,为提高GNN的效率和有效性,为GNN上配置知识蒸馏成为一种新趋势。然而,据我们所知,现有的应用于GNN的知识蒸馏方法都采用了预定义的蒸馏过程,这些过程由几个超参数控制,而不受蒸馏模型性能的监督。蒸馏和评价之间的这种隔离会导致次优结果。在这项工作中,我们旨在提出一个通用的知识蒸馏框架,可以应用于任何预先训练的GNN模型,以进一步提高它们的性能。为了解决分离问题,我们提出了参数化和学习适合蒸馏GNN的蒸馏过程。具体地说,我们没有像以前的大多数工作那样引入一个统一的温度超参数,我们将学习节点特定的蒸馏温度,以获得更好的蒸馏模型性能。我们首先通过一个关于节点邻域编码和预测分布的函数将每个节点的温度参数化,然后设计了一种新的迭代学习过程来进行模型蒸馏和温度学习。我们还引入了我们的方法的一个可扩展的变体来加速模型训练。在5个基准数据集上的实验结果表明,我们提出的框架可以应用于5个流行的GNN模型,并使其预测精度平均相对提高3.12%。此外,可扩展的变体模型以1%的预测精度为代价,使训练速度提高了8倍。

1 简介

图神经网络(GNNs)已经成为最先进的图上的半监督学习技术,并在过去的五年中受到了广泛的关注。数以百计的图神经网络模型已经被提出并成功地应用于各种领域,如计算机视觉、自然语言处理和数据挖掘。近年来,在图神经网络中加入知识蒸馏来达到更好的效率或效果是一种新趋势。在知识蒸馏中,学生模型通过训练来模仿预先训练的教师模型的软预测来学习知识。从效率的角度来看,知识蒸馏可以将深层的图卷积神经网络(GCN)模型(教师)压缩为浅层模型(学生),从而实现更快的推理。从有效性的角度来看,知识蒸馏可以提取图神经网络模型(教师)的知识,并将其注入到设计良好的非图神经网络模型(学生)中,从而利用更多的先验知识,得到更准确的预测结果。

除了教师和学生的选择,蒸馏过程决定了教师和学生模型的软预测在损失函数中如何匹配,也对蒸馏后的学生对下游任务的预测表现至关重要。例如,全局超参数“温度”在知识蒸馏中被广泛采用,它软化了教师模型和学生模型的预测,以促进知识转移。然而,据我们所知,应用于图神经网络的现有知识蒸馏方法都采用了预先定义的蒸馏过程,即只有超参数而没有任何可学习的参数。换句话说,蒸馏过程是启发式或经验式设计的,没有任何来自蒸馏学生的监督,这将分离蒸馏与评价,从而导致次优结果。针对现有的图上知识蒸馏方法的上述缺点,本文提出了一种参数化蒸馏过程的框架。

在本工作中,我们的目标是提出一个通用的知识蒸馏框架,可以应用于任何预训练过的图神经网络模型,以进一步提高其性能。注意,我们关注的是蒸馏过程的研究,而不是学生模型的选择,因此,就像BAN建议的那样,简单地让一个学生模型拥有与其老师相同的神经结构。为了克服蒸馏和评估之间的隔离问题,我们没有将全局温度作为超参数引入,而是创新性地提出通过蒸馏GNN学生的表现来学习特定节点的温度。

本工作的主要思想是为图上的每个节点学到一个特定的温度。我们通过一个关于节点邻域编码和节点预测分布的函数来参数化每个节点的温度。由于传统知识蒸馏框架存在隔离问题,经过蒸馏的学生的性能对节点温度的偏导数不存在,这使得温度参数化中的参数学习有着一定的困难。因此,我们设计了一种新的迭代学习过程,交替执行准备、提取和学习步骤,用于参数训练。在准备阶段,我们将根据当前参数计算每个节点的温度,并建立基于节点温度的知识蒸馏损失;在蒸馏阶段,学生模型的参数将根据蒸馏损失进行更新;在学习阶段,温度建模中的参数将更新,以提高提取的图神经网络学生模型的分类精度。

WSDM2023 | 学习蒸馏图神经网络

2 预备知识

2.1 节点分类

节点分类是一种典型的图上的半监督学习任务,其目的是对给定的标记节点和图结构中的未标记节点进行分类,被广泛应用于许多GNN模型的评估中。形式上,给定一个连通图是顶点集, 是边集,节点分类的任务是基于图结构、有标签的节点集和节点特征来预测没有标签的节点集中每个节点v的标签。其中矩阵X的每一行表示节点v的d-dimensional特征。设为节点标签的集合,则每个节点的真实标签可以表示为一个维的独热向量

2.2 图神经网络

图神经网络可以通过迭代聚合邻居信息,即消息传递机制,将每个节点v编码为|Y|维logit向量。在本文中,我们提出的算法不是针对特定的图神经网络模型设计的,而是可以应用于任何图神经网络。因此,我们简单地将图神经网络编码器以黑盒形式描述化为:

其中是图神经网络中的可学习参数,是在softmax函数归一化之后的预测标签分布。然后图神经网络会对每个有标签的节点最小化该节点的真实标签与预测标签之间的距离,通常采用交叉熵损失来训练参数:

2.3 知识蒸馏

在本工作中,我们关注的是蒸馏过程的研究,而不是学生模型的选择。因此,我们只需让教师模型和学生模型具有BAN建议的相同的神经结构,并分别表示为,参数分别为。给定教师模型的预训练参数,我们将通过对之间的软预测进行对齐,训练学生模型的参数。从形式上讲,知识蒸馏框架旨在优化:

其中第一项是学生预测和教师预测之间的交叉熵,第二项是中节点的学生预测与真实标签的交叉熵,是平衡超参数。许多知识蒸馏方法会引入额外的温度超参数来软化教师和学生的预测:

其中是温度超参数。温度等于1时对应原始的softmax操作。温度越高,预测就越软(趋向均匀分布),而温度越低,预测就越硬(趋向独热分布)。在最流行的蒸馏框架[中,所有的温度都被设置为相同的超参数,即为每个节点𝑣设置。通过调整全局温度超参数,然后对经过蒸馏的学生模型进行评估,并期望其性能优于教师。

3 为每个节点学习特定温度

我们没有引入全局温度作为超参数,而是创新地提出学习特定节点的温度,以获得更好的蒸馏性能。我们将首先介绍如何在温度参数化中引入可学习参数,然后设计一种基于迭代学习过程的参数训练新算法。

3.1 参数化温度

直接为每个节点指定一个自由参数作为节点特定温度将导致严重的过拟合问题。因此,我们假设具有相似编码和邻域预测的节点应该具有相似的蒸馏温度。在实际应用中,每个节点𝑣的温度可以通过一个函数来参数化,该函数需要用到以下特征:(1)学生的logit向量,它直接表征了学生模型当前的预测状态;(2) Logits向量的L2范数,由于softmax函数中的指数算子,较大的范数通常表示较硬的预测分布;(3)中心节点邻居的预测熵,描述了节点邻居的标签多样性。直观上,上述所有特性都会影响模型预测的置信度,因此在温度参数化中都应考虑。形式上,我们将所有学生模型的温度设置为1,以实现更精确的预测,并将教师温度参数化:

我们使用教师而不是学生来建模邻居的预测熵,以获得更好的数值稳定性。我们在实验部分将研究每个连接组件的影响,并讨论我们学习的温度。此外,为了避免梯度爆炸或消失的问题,我们还通过一个基于sigmoid操作的函数,将温度限制在范围内。注意,用来建模温度的其他特征也可能存在,但我们发现我们使用的三个特征足以提高性能,并且这三个特征都是有用的。

3.2迭代学习过程。

为了监督节点特定温度的训练,我们将标记的节点集划分为两个不相交的节点集仍然在损失函数的第二项中用于蒸馏,而用于评估蒸馏的学生模型和学习节点温度。形式上,蒸馏部分的损失可以写成:

评估蒸馏的学生和监督温度的损失是:

然而,由于蒸馏和评估之间的隔离,评估损失只与学生模型的参数有关,评估损失对温度参数的偏导数不存在,这使它不可能通过反向传播来学习温度。为了解决这个问题,我们提出了一个新的迭代学习过程,交替执行以下准备、提炼和学习步骤: 准备步骤:首先计算每个节点v的温度,然后设定好蒸馏损失。蒸馏步骤:对于模型蒸馏,我们通过单步反向传播来更新学生模型参数:

其中是蒸馏过程的学习率。学习步骤:我们用更新之后的学生参数计算评估损失,然后通过链式规则在温度参数上执行反向传播:

其中是蒸馏过程的学习率。这里我们将评估损失对温度参数的偏导分解为评估损失对学生参数的偏导和学生参数对温度参数的偏导的乘积,这两项都可以分别由上述式子的偏导计算得到。通过迭代执行准备、蒸馏和学习步骤,我们可以训练参数化的节点特定温度,从而提高蒸馏学生的预测性能。整个算法的伪代码如下:

WSDM2023 | 学习蒸馏图神经网络

我们还设计了一个轻量化的模型变体来加速模型训练,更多细节请参考原文。

4 实验

在本节中,我们对五个基准数据集进行了实验,以回答以下研究问题(RQs): •RQ1:我们的LTD蒸馏出的GNN学生是否优于其他知识蒸馏框架蒸馏出的学生?与其他蒸馏框架相比,我们模型的效率如何? •RQ2:我们的模型在不同的环境下(如消融研究,GNN教师/学生的不同组合)表现如何? •RQ3:我们可以从学习的LTD参数(即特定节点的温度)中观察到什么模式?

4.1 主实验

我们给出了5个基准数据集和5个GNN模型的结果。我们在表中将最好的结果加粗,并报告了我们的方法相比教师与FT中较好结果的相对提升。

WSDM2023 | 学习蒸馏图神经网络

4.2 消融实验

我们进行了消融研究,以探讨每个特征在温度参数化中的作用。如下图所示,我们将我们的完整模型(V0)与三个消融模型(V1-V3)进行了比较。下图显示了所有三个组件对整体性能的贡献,这说明了每个组件对温度建模的必要性。

WSDM2023 | 学习蒸馏图神经网络

4.3 学习到的温度分析

我们分析了在5 × 5 = 25 gnn数据集组合中学习到的节点特定温度,并提出了以下基于GAT的案例研究,以说明LTD如何帮助学习更好的蒸馏。

(1) 首先,我们计算了随机初始化温度和学习温度之间的Pearson相关系数,以证明训练过程后节点温度发生了显著变化,并且真正具有节点特异性。

(2)我们观察到“令人困惑”的类(即与其他类混合)中的节点往往具有更高的温度。例如,在下图中我们将GAT老师学习到的节点嵌入可视化,并使用节点颜色来表示它们的标签。

WSDM2023 | 学习蒸馏图神经网络

(3)我们观察到具有较小的L2范数的节点倾向于具有较高的温度。

(4)请注意,我们允许负节点温度,这将完全颠倒先训练的教师的预测。我们观察到具有负学习温度的节点很可能被GNN老师错误地预测。

5 结论

在本文中,我们提出了一种新的知识蒸馏框架LTD,可以应用于任何预训练的GNN模型,以进一步提高其预测性能。我们没有像以前的大多数工作那样引入一个全局温度超参数,而是创新地提出通过蒸馏学生的表现来学习节点特定的蒸馏温度。具体而言,我们通过邻域编码和预测情况的函数来参数化每个节点的温度,并设计了一种新的迭代学习过程,用于模型提取和参数学习。作为一种成本有效的选择,LTD的可扩展变体提出了启发式更新节点特定温度。我们在五个基准数据集上进行了实验,并表明我们提出的框架可以成功地应用于五个流行的GNN模型。大量的研究进一步证明了该方法的有效性。



本期责任编辑:杨成
本期编辑:刘佳玮

北邮 GAMMA Lab 公众号
主编:石川
责任编辑:王啸、杨成
编辑:刘佳玮

长按下图并点击“识别图中二维码

即可关注北邮 GAMMA Lab 公众号

WSDM2023 | 学习蒸馏图神经网络

原文始发于微信公众号(北邮 GAMMA Lab):WSDM2023 | 学习蒸馏图神经网络

  • 左青龙
  • 微信扫一扫
  • weinxin
  • 右白虎
  • 微信扫一扫
  • weinxin
admin
  • 本文由 发表于 2023年3月20日22:21:07
  • 转载请保留本文链接(CN-SEC中文网:感谢原作者辛苦付出):
                   WSDM2023 | 学习蒸馏图神经网络http://cn-sec.com/archives/1616314.html

发表评论

匿名网友 填写信息