Inductive Representation Learning on Large Graph
引言概要
两类图算法
在大规模图上学习节点embedding,在很多任务中非常有效,一般来说分位l两类算法:
-
基于游走类的算法,如学习节点拓扑结构的 DeepWalk;
-
基于图神经网络的的算法,如同时学习邻居特征和拓扑结构的GCN;
GraphSage对GCN的改进
在之前的分享中我们介绍过经典的GCN算法:主要通过拉普拉斯矩阵进行特征值分解做邻接关联和迭代更新
但是图的特征值分解是一个特别耗时的操作,具有O(n^3) 的复杂度,很难扩展到海量节点的场景中
GCN的缺点:
1、融合时边权值是固定的,不够灵活;(GAT)
2、可扩展性差,因为是全图卷积融合和梯度更新,当图比较大时,计算开销极大;(GraphSAGE)
3、层数加深时,结果极容易平滑,每个点的特征十分相似;(DeepGCN)
GraphSAGE(Graph SAmple and aggreGatE) 关注GCN的灵活性问题:添加一个节点,意味着许许多多与之相关的节点的表示都应该调整,这会带来极大的计算开销,即使增加几个节点,也要完全重新训练所有的节点;
它从两个方面对传统的GCN做了改进:
-
在训练时不同于GCN的全图优化,通过采样方式将对节点的邻居抽样,这使得大规模图数据的分布式训练成为可能,并且使得网络可以学习没有见过的节点;
-
研究了若干种邻居聚合的方式,并通过实验和理论分析对比了不同聚合方式的优缺点;
两种学习模式
文章中引入了两种学习模式:
transductive(直推式学习)
-
在图中学习目标是直接生成当前节点的embedding;
DeepWalk, Line 把每个节点embedding作为参数,并通过SGD优化;GCN,在训练过程中使用图的拉普拉斯矩阵进行计算;
-
所有的数据在训练的时候都可以拿到(预测的点或边关系结构也已经在图里),学习的过程是在固定的图上进行学习,一旦图中的某些节点发生变化,则需要对整个图进行重新训练和学习;
inductive(归纳学习)
-
节点的embedding是通过邻接节点传导聚合生成的(学习在图上生成节点Embedding的方法而不是直接学习节点的Embedding);
-
预测数据在模型训练时不用看到,类似我们平常做算法模型的样子,训练和预测的数据分开,换句话说图结构可以不是固定的,加入新的节点可以从邻接节点传导聚合出来;
现实情况大多情况图是会演化的,当网络结构改变以及新节点的出现,直推式学习需要重新训练(复杂度高且可能会导致embedding会偏移),很难落地在需要快速生成未知节点embedding的机器学习系统上。
归纳学习无疑拥有更大的应用价值:
-
一是改变节点不需要对整个图进行重新学习,大大减小了计算量;
使用 K个邻居节点来计算当前节点的属性,计算量小很多,时间复杂度约为 O(|E|d)
-
二是能够对没有见过的数据进行推测,这将增大其应用价值;
PinSage: 第一个基于GCN的工业级推荐系统,为GCN落地提供了实践经验,而本文是PinSAGE的理论基础,同样出自斯坦福,是GCN非常经典和实用的论文。
Github代码地址: https://github.com/williamleif/GraphSAGE
GraphSAGE框架
前向传播
可视化例子
下图是GraphSAGE 生成目标节点(红色)embedding并供下游任务预测的过程:
-
先对邻居随机采样,降低计算复杂度,每一跳抽样的邻居数不多于S_k个(图中一跳邻居采样数=3,二跳邻居采样数=5)
-
生成目标节点emebedding,先聚合2跳邻居特征,生成一跳邻居embedding,再聚合一跳邻居embedding,生成目标节点embedding;
-
将目标节点的embedding作为全连接层的输入,预测目标节点的标签(有监督或者无监督);
从上面的介绍中我们可以看出,GraphSAGE的思想就是不断的聚合邻居信息,然后进行迭代更新。随着迭代次数的增加,每个节点的聚合的信息几乎都是全局的,这点和CNN中感受野的思想类似。
伪代码
GraphSage进一步解释
GraphSAGE的一个强大之处是:它在一个子集学到的模型也可以应用,原因是因为GraphSage的参数是共享的。
如图所示,在计算节点A和节点B的embedding时,它们在计算两层节点用的参数相同。
当有一个新的图或者有一个节点加入到已训练的图中时,我们只需要知道这个新图或者新节点的结构信息,通过共享的参数,便可以得到它们的特征向量。(跟我们在想的直接推导有区别)
GraphSAGE主要解决了两个问题:
-
解决了预测中unseen nodes的问题,原来的GCN训练时,需要看到所有nodes的图数据;
-
解决了图规模较大,全图进行梯度更新,内存消耗大,计算慢的问题;
以上两点都是通过一个方式解决的,也就是采子图的方式,由于采取的子图是局部图且是随机的,从而大大增加模型的可扩展性,还可以看看Cluster-GCN,它通过图聚类的方式去划分子图分区,进一步提高了计算效率。
GraphSage如何聚合信息
假设我们要聚合K次,则需要有K个聚合函数(aggregator):
-
每一次聚合,都是把上一层得到的各个node的特征聚合一次,得到该层的特征;
-
如此反复聚合K次,得到该node最后的特征;
-
最下面一层的node特征就是输入的node features;
-
每一层的node的表示都是由上一层生成的,跟本层的其他节点无关;
在GraphSAGE的实践中,作者发现,K不必取很大的值,当K=2时,效果就很好了,也就是只用扩展到2阶邻居即可。至于邻居的个数,文中提到S1×S2<=500,即两次扩展的邻居数之际小于500,大约每次只需要扩展20来个邻居即可。
这也是合情合理,例如在现实生活中,对你影响最大就是亲朋好友,这些属于一阶邻居,然后可能你偶尔从他们口中听说一些他们的同事、朋友的一些故事,这些会对你产生一定的影响,这些人就属于二阶邻居。
但是到了三阶,可能基本对你不会产生什么影响了,例如你听你同学说他同学听说她同学的什么事迹,是不是很绕口,因为你基本不会听到这样的故事,你所接触到的、听到的、看到的,基本都在“二阶”的范围之内。
聚合函数
伪代码第4行可以使用不同聚合函数,介绍几种满足排序不变量(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)的聚合函数:
-
平均聚合函数;
-
GCN归纳式聚合;
-
LSTM聚合;
-
pooling聚合;
-
平均聚合:
-
先对邻居embedding中每个维度取平均,然后与目标节点embedding拼接后进行非线性转换
-
GCN aggregator:
-
直接对目标节点和所有邻居embedding中每个维度取平均(替换伪代码中第5、6行),后再非线性转换:
-
LSTM聚合:
-
LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机的邻居序列embedding{x_t, t in N(v)}作为LSTM输入
-
Pooling聚合:
-
先对每个邻居节点上一层embedding进行非线性转换(等价单个全连接层,每一维度代表在某方面的表示),再按element-wise 应用max/mean pooling,捕获邻居集上在某方面的突出的表现,以此表示目标节点embedding。;
无监督和有监督损失设定
损失函数根据具体应用情况,可以使用基于图的无监督损失和有监督损失。
实验
3.1 实验目的
-
比较GraphSAGE 相比baseline 算法的提升效果;
-
比较GraphSAGE的不同聚合函数;
3.2 数据集及任务
-
Citation 论文引用网络(节点分类)
-
Reddit web论坛 (节点分类)
-
PPI 蛋白质网络 (graph分类)
3.3 比较方法
-
随机分类器;
-
手工特征(非图特征);
-
deepwalk(图拓扑特征);
-
deepwalk+手工特征;
-
GraphSAGE四个变种 ,并无监督生成embedding输入给LR 和 端到端有监督;
(分类器均采用LR)
3.4 GraphSAGE 设置
-
K=2,聚合两跳内邻居特征;
-
S1=25,S2=10: 对一跳邻居抽样25个,二跳邻居抽样10个;
-
RELU 激活单元;
-
Adam 优化器;
-
对每个节点进行步长为5的50次随机游走;
-
负采样参考word2vec,按平滑degree进行,对每个节点采样20个;
-
保证公平性:所有版本都采用相同的minibatch迭代器、损失函数、邻居抽样器;
3.5 运行时间和参数敏感性
-
计算时间:下图A中GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)。
-
邻居抽样数量:下图B中邻居抽样数量递增,边际收益递减(F1),但计算时间也变大。
-
聚合K跳内信息:在GraphSAGE, K=2 相比K=1 有10-15%的提升;但将K设置超过2,边际效果上只有0-5%的提升,但是计算时间却变大了10-100倍。
3.6 效果
-
GraphSAGE相比baseline 效果大幅度提升;
-
GraphSAGE有监督版本比无监督效果好;
-
LSTM和pool的效果较好;
原文始发于微信公众号(风物长宜 AI):图神经网络系列四:GraphSage
- 左青龙
- 微信扫一扫
-
- 右白虎
- 微信扫一扫
-
评论