图神经网络系列四:GraphSage

admin 2025年6月6日03:37:30评论10 views字数 3775阅读12分35秒阅读模式

Inductive Representation Learning on Large Graph

引言概要

两类图算法

在大规模图上学习节点embedding,在很多任务中非常有效,一般来说分位l两类算法:

  • 基于游走类的算法,如学习节点拓扑结构的 DeepWalk;

  • 基于图神经网络的的算法,如同时学习邻居特征和拓扑结构的GCN;

图神经网络系列四:GraphSage

GraphSage对GCN的改进 

在之前的分享中我们介绍过经典的GCN算法:主要通过拉普拉斯矩阵进行特征值分解做邻接关联和迭代更新

但是图的特征值分解是一个特别耗时的操作,具有O(n^3) 的复杂度,很难扩展到海量节点的场景中

GCN的缺点:

1、融合时边权值是固定的,不够灵活;(GAT)

2、可扩展性差,因为是全图卷积融合和梯度更新,当图比较大时,计算开销极大;(GraphSAGE)

3、层数加深时,结果极容易平滑,每个点的特征十分相似;(DeepGCN) 

GraphSAGE(Graph SAmple and aggreGatE) 关注GCN的灵活性问题添加一个节点,意味着许许多多与之相关的节点的表示都应该调整,这会带来极大的计算开销,即使增加几个节点,也要完全重新训练所有的节点;

它从两个方面对传统的GCN做了改进:

  • 在训练时不同于GCN的全图优化,通过采样方式将对节点的邻居抽样,这使得大规模图数据的分布式训练成为可能,并且使得网络可以学习没有见过的节点;

  • 研究了若干种邻居聚合的方式,并通过实验和理论分析对比了不同聚合方式的优缺点;

两种学习模式

文章中引入了两种学习模式:

transductive(直推式学习)

  • 在图中学习目标是直接生成当前节点的embedding;

DeepWalk, Line 把每个节点embedding作为参数,并通过SGD优化;GCN,在训练过程中使用图的拉普拉斯矩阵进行计算;

  • 所有的数据在训练的时候都可以拿到(预测的点或边关系结构也已经在图里),学习的过程是在固定的图上进行学习,一旦图中的某些节点发生变化,则需要对整个图进行重新训练和学习;

inductive(归纳学习)

  • 节点的embedding是通过邻接节点传导聚合生成的(学习在图上生成节点Embedding的方法而不是直接学习节点的Embedding);

  • 预测数据在模型训练时不用看到,类似我们平常做算法模型的样子,训练和预测的数据分开,换句话说图结构可以不是固定的,加入新的节点可以从邻接节点传导聚合出来;

现实情况大多情况图是会演化的,当网络结构改变以及新节点的出现,直推式学习需要重新训练(复杂度高且可能会导致embedding会偏移),很难落地在需要快速生成未知节点embedding的机器学习系统上。

归纳学习无疑拥有更大的应用价值:

  • 一是改变节点不需要对整个图进行重新学习,大大减小了计算量;

使用 K个邻居节点来计算当前节点的属性,计算量小很多,时间复杂度约为 O(|E|d)

  • 二是能够对没有见过的数据进行推测,这将增大其应用价值;

PinSage: 第一个基于GCN的工业级推荐系统,为GCN落地提供了实践经验,而本文是PinSAGE的理论基础,同样出自斯坦福,是GCN非常经典和实用的论文。

Github代码地址: https://github.com/williamleif/GraphSAGE

GraphSAGE框架

前向传播

可视化例子

下图是GraphSAGE 生成目标节点(红色)embedding并供下游任务预测的过程:

图神经网络系列四:GraphSage

  1. 先对邻居随机采样,降低计算复杂度,每一跳抽样的邻居数不多于S_k个(图中一跳邻居采样数=3,二跳邻居采样数=5)

  2. 生成目标节点emebedding,先聚合2跳邻居特征,生成一跳邻居embedding,再聚合一跳邻居embedding,生成目标节点embedding;

  3. 将目标节点的embedding作为全连接层的输入,预测目标节点的标签(有监督或者无监督);

从上面的介绍中我们可以看出,GraphSAGE的思想就是不断的聚合邻居信息,然后进行迭代更新。随着迭代次数的增加,每个节点的聚合的信息几乎都是全局的,这点和CNN中感受野的思想类似。

伪代码

图神经网络系列四:GraphSage
图神经网络系列四:GraphSage
图神经网络系列四:GraphSage

GraphSage进一步解释

GraphSAGE的一个强大之处是:它在一个子集学到的模型也可以应用,原因是因为GraphSage的参数是共享的。

如图所示,在计算节点A和节点B的embedding时,它们在计算两层节点用的参数相同。

图神经网络系列四:GraphSage

当有一个新的图或者有一个节点加入到已训练的图中时,我们只需要知道这个新图或者新节点的结构信息,通过共享的参数,便可以得到它们的特征向量。(跟我们在想的直接推导有区别)

GraphSAGE主要解决了两个问题:

  • 解决了预测中unseen nodes的问题,原来的GCN训练时,需要看到所有nodes的图数据;

  • 解决了图规模较大,全图进行梯度更新,内存消耗大,计算慢的问题;

以上两点都是通过一个方式解决的,也就是采子图的方式,由于采取的子图是局部图且是随机的,从而大大增加模型的可扩展性,还可以看看Cluster-GCN,它通过图聚类的方式去划分子图分区,进一步提高了计算效率。

图神经网络系列四:GraphSage

GraphSage如何聚合信息

假设我们要聚合K次,则需要有K个聚合函数(aggregator)

  • 每一次聚合,都是把上一层得到的各个node的特征聚合一次,得到该层的特征;

  • 如此反复聚合K次,得到该node最后的特征;

  • 最下面一层的node特征就是输入的node features;

  • 每一层的node的表示都是由上一层生成的,跟本层的其他节点无关;

图神经网络系列四:GraphSage

图神经网络系列四:GraphSage

在GraphSAGE的实践中,作者发现,K不必取很大的值,当K=2时,效果就很好了,也就是只用扩展到2阶邻居即可。至于邻居的个数,文中提到S1×S2<=500,即两次扩展的邻居数之际小于500,大约每次只需要扩展20来个邻居即可

这也是合情合理,例如在现实生活中,对你影响最大就是亲朋好友,这些属于一阶邻居,然后可能你偶尔从他们口中听说一些他们的同事、朋友的一些故事,这些会对你产生一定的影响,这些人就属于二阶邻居。

但是到了三阶,可能基本对你不会产生什么影响了,例如你听你同学说他同学听说她同学的什么事迹,是不是很绕口,因为你基本不会听到这样的故事,你所接触到的、听到的、看到的,基本都在“二阶”的范围之内

聚合函数

伪代码第4行可以使用不同聚合函数,介绍几种满足排序不变量(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)的聚合函数:

  • 平均聚合函数;

  • GCN归纳式聚合;

  • LSTM聚合;

  • pooling聚合;

  • 平均聚合:

    • 先对邻居embedding中每个维度取平均,然后与目标节点embedding拼接后进行非线性转换

图神经网络系列四:GraphSage

  • GCN aggregator:

    • 直接对目标节点和所有邻居embedding中每个维度取平均(替换伪代码中第5、6行),后再非线性转换:

图神经网络系列四:GraphSage

  • LSTM聚合:

    • LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机的邻居序列embedding{x_t, t in N(v)}作为LSTM输入

  • Pooling聚合:

    • 先对每个邻居节点上一层embedding进行非线性转换(等价单个全连接层,每一维度代表在某方面的表示),再按element-wise 应用max/mean pooling,捕获邻居集上在某方面的突出的表现,以此表示目标节点embedding。

图神经网络系列四:GraphSage

无监督和有监督损失设定

损失函数根据具体应用情况,可以使用基于图的无监督损失有监督损失

图神经网络系列四:GraphSage

实验

3.1 实验目的

  1. 比较GraphSAGE 相比baseline 算法的提升效果;

  2. 比较GraphSAGE的不同聚合函数;

3.2 数据集及任务

  1. Citation 论文引用网络(节点分类)

  2. Reddit web论坛 (节点分类)

  3. PPI 蛋白质网络 (graph分类)

3.3 比较方法

  1. 随机分类器;

  2. 手工特征(非图特征);

  3. deepwalk(图拓扑特征);

  4. deepwalk+手工特征;

  5. GraphSAGE四个变种 ,并无监督生成embedding输入给LR 和 端到端有监督;

(分类器均采用LR)

3.4 GraphSAGE 设置

  • K=2,聚合两跳内邻居特征;

  • S1=25,S2=10: 对一跳邻居抽样25个,二跳邻居抽样10个;

  • RELU 激活单元;

  • Adam 优化器;

  • 对每个节点进行步长为5的50次随机游走;

  • 负采样参考word2vec,按平滑degree进行,对每个节点采样20个;

  • 保证公平性:所有版本都采用相同的minibatch迭代器、损失函数、邻居抽样器;

3.5 运行时间和参数敏感性

  1. 计算时间:下图A中GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)。

  2. 邻居抽样数量:下图B中邻居抽样数量递增,边际收益递减(F1),但计算时间也变大。 

  3. 聚合K跳内信息:在GraphSAGE, K=2 相比K=1 有10-15%的提升;但将K设置超过2,边际效果上只有0-5%的提升,但是计算时间却变大了10-100倍。

图神经网络系列四:GraphSage

3.6 效果

  1. GraphSAGE相比baseline 效果大幅度提升;

  2. GraphSAGE有监督版本比无监督效果好;

  3. LSTM和pool的效果较好;

图神经网络系列四:GraphSage

原文始发于微信公众号(风物长宜 AI):图神经网络系列四:GraphSage

免责声明:文章中涉及的程序(方法)可能带有攻击性,仅供安全研究与教学之用,读者将其信息做其他用途,由读者承担全部法律及连带责任,本站不承担任何法律及连带责任;如有问题可邮件联系(建议使用企业邮箱或有效邮箱,避免邮件被拦截,联系方式见首页),望知悉。
  • 左青龙
  • 微信扫一扫
  • weinxin
  • 右白虎
  • 微信扫一扫
  • weinxin
admin
  • 本文由 发表于 2025年6月6日03:37:30
  • 转载请保留本文链接(CN-SEC中文网:感谢原作者辛苦付出):
                   图神经网络系列四:GraphSagehttps://cn-sec.com/archives/4138563.html
                  免责声明:文章中涉及的程序(方法)可能带有攻击性,仅供安全研究与教学之用,读者将其信息做其他用途,由读者承担全部法律及连带责任,本站不承担任何法律及连带责任;如有问题可邮件联系(建议使用企业邮箱或有效邮箱,避免邮件被拦截,联系方式见首页),望知悉.

发表评论

匿名网友 填写信息