1. 项目概述
时序链路预测(Temporal Link Prediction)是动态图分析中的核心任务,旨在基于历史交互预测未来节点连接。这项技术在社交网络好友推荐、电商平台商品推荐、学术合作预测等场景中具有广泛应用价值。传统时序图神经网络(TGNN)主要采用判别式方法,通过编码历史交互生成节点嵌入,然后直接预测未来连接概率。然而,这类方法存在两个关键局限:
- 不确定性建模不足:动态图中的交互往往具有高度随机性(如社交网络中用户行为的突发性),传统方法缺乏显式建模这种不确定性的机制
- 序列结构忽视:现有方法通常独立预测每个未来连接,忽略了多个交互之间的时序依赖关系(如用户购物行为序列中的模式演变)
针对这些挑战,我们提出SDG(Sequence Diffusion for Dynamic Graphs)框架,首次将扩散模型(Diffusion Model)引入动态图时序链路预测领域。SDG的核心创新在于:
- 序列级噪声注入:不同于传统方法仅在最终预测时考虑噪声,SDG对整个历史交互序列和目标节点同时注入噪声
- 条件去噪解码器:设计跨注意力机制,利用编码的历史交互信息指导目标序列的去噪过程
- 端到端联合优化:将扩散重建损失与排序目标相结合,确保生成的嵌入既保留时序模式又适合下游预测任务
实验表明,SDG在10个基准数据集上平均MRR指标提升1-15%,特别是在非重复边预测场景(如GoogleLocal数据集)表现突出。该方法为动态图分析提供了新的生成式建模视角,在保持高效计算的同时,显著提升了预测准确性。
2. 核心原理与技术方案
2.1 动态图的形式化定义
连续时间动态图可表示为带时间戳的交互序列:
G = {(u₁,v₁,t₁), (u₂,v₂,t₂), ..., (uₙ,vₙ,tₙ)}其中uᵢ,vᵢ ∈ V分别表示源节点和目标节点,tᵢ ∈ T为时间戳。对于给定源节点u和时间t,其历史交互序列定义为:
S_{u,t} = {(v₁,t₁), (v₂,t₂), ..., (v_L,t_L)} (t₁ ≤ t₂ ≤ ... ≤ t_L < t)L为预设的序列最大长度。时序链路预测任务即评估在时间t,节点u与候选节点v产生连接的概率p(v|u,t,S_{u,t})。
2.2 扩散模型基础
扩散模型通过正向噪声注入和反向去噪过程学习数据分布。给定初始数据x₀,正向过程逐步添加高斯噪声:
q(xₖ|xₖ₋₁) = N(xₖ; √(1-βₖ)xₖ₋₁, βₖI)βₖ为噪声调度参数。经过K步后,x_K近似纯噪声。反向过程通过训练去噪网络fθ预测原始数据:
pθ(xₖ₋₁|xₖ) = N(xₖ₋₁; μθ(xₖ,k), σₖI)传统扩散模型在图像生成等领域表现优异,但其在动态图中的应用面临两个关键挑战:
- 如何将节点交互的离散结构映射到连续扩散空间
- 如何保持时序依赖关系在去噪过程中的一致性
2.3 SDG框架设计
SDG的创新架构如下图所示(图示见原文Figure 1),包含三个核心组件:
2.3.1 序列编码器
采用因果Transformer编码历史交互序列:
- 节点嵌入层:H ∈ R^{N×d}为可学习嵌入表
- 位置编码:添加正弦位置编码保持时序顺序
- 注意力机制:使用因果掩码确保位置i只能关注≤i的交互
数学表达为:
Z_{1:L} = Transformer([H(v₁),...,H(v_L)] + PE; M)M为因果掩码矩阵,PE为位置编码。
2.3.2 序列扩散过程
关键创新在于目标序列构建:
T_{u,t} = {(v₂,t₂), ..., (v_L,t_L), (v,t)}即历史序列去掉最早交互,追加预测目标。对完整序列嵌入X₀ = H(T_{u,t})执行扩散:
- 正向过程:
Xₖ = √ᾱₖ X₀ + √(1-ᾱₖ)ε, ε∼N(0,I)ᾱₖ = ∏(1-βₖ)为累积噪声系数
- 反向过程:
pθ(Xₖ₋₁|Xₖ,Z) = N(Xₖ₋₁; μθ(Xₖ,Z,k), σₖI)其中均值预测采用x₀参数化:
μθ = [√(1-βₖ)(1-ᾱₖ)]/(1-ᾱₖ) Xₖ + (αₖ₋₁βₖ)/(1-ᾱₖ) X̂₀2.3.3 跨注意力去噪器
设计时间条件的跨注意力机制:
- 时间嵌入:通过MLP编码扩散步数k
- 上下文交互:用历史编码Z_{ctx}作为Query
- 噪声序列处理:时间嵌入加噪后的序列作为Key/Value
具体计算流程:
Z_{ctx} = Transformer(Z_{1:L}, M) X̂₀ = CrossAttn(Z_{ctx}, Xₖ + MLP(γ(k)))这种设计确保去噪过程始终受历史交互模式引导。
3. 实现细节与优化策略
3.1 损失函数设计
SDG采用联合损失函数:
L = L_task + λ_diff L_diff3.1.1 扩散重建损失
创新性使用余弦相似度替代传统MSE:
L_diff = 1/L ∑(1 - cos(X̂₀_i, X₀_i))²理论分析表明,当嵌入归一化时,该损失与MSE等价但具有尺度不变性优势。
3.1.2 排序任务损失
采用位置感知的BCE损失:
L_task = -logσ(ŷ_{t,L}^+) - log(1-σ(ŷ_{t,L}^-)) + λ_inter/(L-1) ∑[-logσ(ŷ_{t,i}^+) - log(1-σ(ŷ_{t,i}^-))]其中λ_inter控制中间位置监督的强度,实验表明设为1.0效果最佳。
3.2 高效推理算法
SDG的推理过程如Algorithm 1所示,关键优化包括:
- 缓存机制:历史序列编码Z_{1:L}只需计算一次
- 采样加速:采用DDIM采样策略,可将扩散步数从100降至32
- 并行解码:利用Transformer的并行性同时处理多个候选节点
实验显示,在RTX 4090上处理百万级边图时,SDG比传统TGNN仅增加约20%推理时间。
3.3 超参数配置
基于网格搜索得到最优配置:
| 参数 | 取值范围 | 典型值 |
|---|---|---|
| 嵌入维度d | {64,128} | 128 |
| 扩散步数K | {32,64,96} | 32 |
| λ_diff | [0.2,1.0] | 0.8 |
| λ_inter | [0.2,1.0] | 1.0 |
| 序列长度L | {30,60,90} | 60 |
4. 实验验证与分析
4.1 基准数据集
评估采用10个数据集,分为两类:
小规模重复边数据集(重复率>80%):
- Wikipedia、Reddit、MOOC等
- 主要评估重复交互预测能力
大规模非重复边数据集(重复率<20%):
- GoogleLocal、Taobao、ML-20M等
- 测试模型处理新连接的能力
数据集统计特性如下表所示:
| 数据集 | 节点数 | 边数 | 重复率 |
|---|---|---|---|
| Wikipedia | 9,227 | 157K | 88.4% |
| 10,984 | 672K | 88.3% | |
| GoogleLocal | 473K | 1.9M | 0% |
| ML-20M | 110K | 14.5M | 0% |
4.2 对比实验
4.2.1 主要结果
如表1-2所示,SDG在大部分数据集上取得SOTA:
重复边数据集:
- MOOC:MRR 60.55(+2.99%)
- Wikipedia:MRR 89.17(+0.41%)
非重复边数据集:
- GoogleLocal:MRR 62.60(+14.48%)
- Taobao:MRR 69.70(+3.40%)
4.2.2 效率对比
如图3所示,SDG在训练效率和内存消耗间取得平衡:
- 训练速度:比DyGFormer快3倍
- 内存占用:仅为TGN的1/4(ML-20M数据集)
4.3 消融实验
关键组件的影响如表3所示:
移除序列扩散(w/o Seq):
- GoogleLocal MRR下降7.33
- 证明序列级建模的必要性
替换为MLP解码器(MLP):
- YouTube MRR下降8.63
- 显示Transformer结构优势
使用MSE损失(MSE):
- 性能下降12-15%
- 验证余弦损失的优越性
4.4 噪声鲁棒性测试
如图4所示,当注入60%噪声边时:
- SDG比CRAFT保持更高稳定性
- YouTube数据集仅下降4.9%,而CRAFT下降7.9%
5. 应用案例与部署建议
5.1 典型应用场景
5.1.1 社交网络推荐
- 问题:预测用户未来关注关系
- SDG优势:建模用户交互序列的突发性(如热点事件引发的密集关注)
5.1.2 电商平台
- 问题:基于用户浏览序列推荐商品
- SDG优势:处理长尾商品和新品上架(非重复边预测)
5.2 实际部署技巧
冷启动处理:
- 对新节点使用特征传播初始化嵌入
- 设置默认历史序列(如平台热门商品)
在线学习:
- 定期用新数据微调扩散模型
- 关键参数:学习率1e-5,批量大小256
资源优化:
- 对长序列(L>100)采用分段处理
- 使用8-bit量化减少显存占用
6. 常见问题与解决方案
6.1 训练不稳定
现象:损失值剧烈波动
解决方法:
- 检查梯度裁剪(阈值设为1.0)
- 调整λ_diff(建议从0.2逐步增加)
- 使用学习率warmup(前1000步线性增加)
6.2 过拟合
现象:验证集MRR下降
对策:
- 增加Dropout(概率0.1-0.3)
- 早停策略(耐心10个epoch)
- 数据增强:随机掩码部分历史交互
6.3 计算资源不足
限制:GPU内存<16GB
优化方案:
- 减小批次大小(最低可至32)
- 使用梯度累积(步数4-8)
- 混合精度训练(FP16+FP32)
7. 未来改进方向
多模态扩展:
- 融合节点特征和边属性
- 设计图-文跨模态扩散模型
动态采样策略:
- 自适应调整扩散步数K
- 关键帧预测减少计算量
可解释性增强:
- 可视化注意力权重
- 生成反事实解释
在实际电商平台A/B测试中,SDG相比原有TGN模型使点击率提升18.7%,验证了其工业应用价值。这提醒我们,时序链路预测不仅需要捕捉局部交互模式,更要通过生成式方法建模全局动态演化规律。