AI 自适应索引:当学习型索引打破 B+ 树的固定结构
2026/6/26 1:55:19 网站建设 项目流程

AI 自适应索引:当学习型索引打破 B+ 树的固定结构

一、B+ 树索引的固定结构,是性能天花板还是安全底线

一张 5 亿行的订单表,按create_time建索引。查询模式分两种:白天查最近 7 天数据(高选择性),夜间批处理查全量历史(低选择性)。B+ 树索引对两种模式一视同仁——同样的树深度、同样的二分查找路径。白天查询多走了 2 层不必要的中间节点,夜间批处理又因为索引回表代价不如全表扫描。

B+ 树的固定结构是通用性妥协的结果:它对所有数据分布和查询模式提供一致的 O(log N) 保证,但无法针对特定分布做优化。学习型索引(Learned Index)的核心命题:用模型替代 B+ 树的搜索结构,根据数据分布自适应调整查找路径,在特定场景下突破 B+ 树的性能边界。

二、学习型索引的架构与递归模型设计

2.1 从 B+ 树到模型:搜索结构的本质

B+ 树索引的本质是一个函数:f(key) → position,给定键值返回其在磁盘上的位置。B+ 树通过逐层二分查找近似这个函数,学习型索引直接用模型拟合这个函数。

flowchart LR subgraph B+树索引 A1[Root Node] --> B1[Level 1] B1 --> C1[Level 2] C1 --> D1[Leaf Page] D1 --> E1[目标位置] end subgraph 学习型索引 A2[Stage 1 模型] --> B2[Stage 2 模型] B2 --> C2[Stage 3 模型] C2 --> D2[精确定位] end subgraph 混合索引 A3[Stage 1 模型] --> B3[Stage 2 B+ 树] B3 --> C3[Leaf Page] end

2.2 递归回归模型(RMI)

The Case for Learned Index Structures 提出的 RMI(Recursive Regression Model)是学习型索引的经典架构:

  1. Stage 1:一个轻量级模型,将 key 空间划分为若干区间
  2. Stage 2:每个区间一个专用模型,在局部数据分布上做更精确的预测
  3. Stage N:递归细分,直到预测误差在可接受范围内

关键约束:模型预测的位置可能存在误差,必须在预测位置 ± δ 范围内做局部搜索(如二分查找或指数搜索),保证正确性。

2.3 ALEX:自适应学习型索引

ALEX(Adaptive Learned Index)在 RMI 基础上增加了自适应节点结构:

  • Gapped Array:节点内预留空位,插入时无需移动大量数据
  • 节点分裂策略:根据插入模式选择顺序分裂或随机分裂
  • 模型重训练:累积误差超过阈值时,局部重训练模型

三、生产级学习型索引的实现与混合架构

3.1 Python 原型:RMI 索引

import numpy as np from typing import List, Optional, Tuple import bisect import logging logger = logging.getLogger(__name__) class LinearModel: """线性回归模型, 拟合 key → position 映射""" def __init__(self): self.weight = 0.0 # 斜率 self.bias = 0.0 # 截距 self.min_key = 0 self.max_key = 0 self.max_pos = 0 self.max_error = 0 # 最大预测误差 def train(self, keys: np.ndarray, positions: np.ndarray): """用最小二乘法训练线性模型""" if len(keys) < 2: self.weight = 0.0 self.bias = float(positions[0]) if len(positions) > 0 else 0.0 return # 最小二乘: w = Cov(k,p) / Var(k), b = mean(p) - w * mean(k) k_mean = np.mean(keys) p_mean = np.mean(positions) cov = np.mean(keys * positions) - k_mean * p_mean var = np.mean(keys ** 2) - k_mean ** 2 if var < 1e-10: self.weight = 0.0 self.bias = p_mean else: self.weight = cov / var self.bias = p_mean - self.weight * k_mean self.min_key = int(keys[0]) self.max_key = int(keys[-1]) self.max_pos = len(keys) - 1 # 计算最大预测误差, 用于确定搜索范围 predictions = self.weight * keys + self.bias errors = np.abs(predictions - positions) self.max_error = int(np.max(errors)) + 1 def predict(self, key: int) -> int: """预测 key 的位置""" pos = int(self.weight * key + self.bias) return max(0, min(pos, self.max_pos)) def predict_with_error(self, key: int) -> Tuple[int, int, int]: """预测位置并返回误差范围 [lo, hi]""" pos = self.predict(key) lo = max(0, pos - self.max_error) hi = min(self.max_pos, pos + self.max_error) return pos, lo, hi class RMIIndex: """递归回归模型索引 (2-stage RMI)""" def __init__(self, num_stage2_models: int = 100): self.num_stage2 = num_stage2_models self.stage1 = LinearModel() self.stage2_models: List[LinearModel] = [] self.keys: np.ndarray = np.array([]) # 排序后的 key 数组 self.trained = False def build(self, keys: List[int]): """构建 RMI 索引""" self.keys = np.sort(np.array(keys, dtype=np.int64)) n = len(self.keys) positions = np.arange(n, dtype=np.float64) # Stage 1: 全局线性模型 self.stage1.train(self.keys, positions) # Stage 2: 将 key 空间划分为 num_stage2 个区间 self.stage2_models = [] boundaries = np.linspace(0, n, self.num_stage2 + 1, dtype=int) for i in range(self.num_stage2): start = boundaries[i] end = boundaries[i + 1] if end - start < 2: # 区间太小, 复用全局模型 model = LinearModel() model.weight = self.stage1.weight model.bias = self.stage1.bias model.min_key = int(self.keys[start]) if start < n else 0 model.max_key = int(self.keys[min(end, n - 1)]) if end <= n else int(self.keys[-1]) model.max_pos = n - 1 model.max_error = self.stage1.max_error self.stage2_models.append(model) else: model = LinearModel() model.train(self.keys[start:end], positions[start:end]) self.stage2_models.append(model) self.trained = True logger.info(f"RMI 索引构建完成: {n} 个key, {self.num_stage2} 个stage2模型") def lookup(self, key: int) -> Optional[int]: """查找 key 的位置, 返回在排序数组中的 index""" if not self.trained: raise RuntimeError("索引未训练") # Stage 1: 选择 stage2 模型 stage1_pos = self.stage1.predict(key) model_idx = int(stage1_pos * self.num_stage2 / max(self.stage1.max_pos, 1)) model_idx = max(0, min(model_idx, self.num_stage2 - 1)) # Stage 2: 用选中的模型预测位置 model = self.stage2_models[model_idx] pos, lo, hi = model.predict_with_error(key) # 在误差范围内做二分搜索, 保证正确性 lo = max(0, lo) hi = min(len(self.keys) - 1, hi) # numpy 数组的二分搜索 left = bisect.bisect_left(self.keys, key, lo, hi + 1) if left < len(self.keys) and self.keys[left] == key: return int(left) # 误差范围内未找到, 扩大搜索范围 expanded_lo = max(0, lo - self.stage1.max_error) expanded_hi = min(len(self.keys) - 1, hi + self.stage1.max_error) left = bisect.bisect_left(self.keys, key, expanded_lo, expanded_hi + 1) if left < len(self.keys) and self.keys[left] == key: logger.warning(f"key={key} 需要扩大搜索范围, 模型误差被低估") return int(left) return None # key 不存在 def range_lookup(self, lo_key: int, hi_key: int) -> List[int]: """范围查找, 返回 [lo_key, hi_key] 范围内的所有 key""" start = self.lookup(lo_key) if start is None: start = bisect.bisect_left(self.keys, lo_key) end = bisect.bisect_right(self.keys, hi_key) return self.keys[start:end].tolist() def stats(self) -> dict: """返回索引统计信息""" if not self.trained: return {} max_errors = [m.max_error for m in self.stage2_models] return { 'total_keys': len(self.keys), 'stage2_models': self.num_stage2, 'stage1_max_error': self.stage1.max_error, 'stage2_avg_max_error': np.mean(max_errors), 'stage2_p99_max_error': np.percentile(max_errors, 99), }

3.2 混合索引架构:模型 + B+ 树的安全网

生产环境不信任单一模型,混合架构是务实选择:

  • 顶层:学习型模型快速定位到粗粒度区间
  • 底层:B+ 树在区间内做精确查找
  • 回退机制:模型误差超过阈值时,自动回退到完整 B+ 树查找
查找流程: 1. 模型预测位置 pos, 误差范围 [pos-δ, pos+δ] 2. 在 B+ 树中从 pos-δ 开始查找 3. 若 δ 步内找到目标, 返回 (快速路径) 4. 若 δ 步内未找到, 回退到 B+ 树根节点从头查找 (安全路径)

四、学习型索引的现实边界与架构妥协

4.1 数据分布假设

学习型索引的核心假设:key → position 的映射可以被模型拟合。当数据完全随机(如 UUID 主键),线性模型的预测误差接近 N/2,退化为全表扫描。学习型索引对有序、有规律的数据分布效果显著,对随机分布几乎无效。

4.2 写入放大与模型重训练

每次插入新数据,key → position 的映射发生变化。如果插入量超过阈值,模型需要重训练。重训练的代价:对 1 亿行数据训练 RMI 约 2-5 秒,期间索引不可用。ALEX 的 Gapped Array 缓解了小量插入的问题,但大量插入仍需重训练。

4.3 误差边界的安全保证

B+ 树的查找正确性由树结构保证,学习型索引的正确性由误差边界 δ 保证。δ 的计算必须保守:取训练集上的最大误差 × 1.5 作为安全系数。但保守的 δ 增加了局部搜索范围,降低性能收益。

4.4 禁用场景

  • 高频写入(WPS > 10 万):模型频繁重训练,开销不可接受
  • 随机分布主键(UUID):模型预测误差过大
  • 强 ACID 要求的事务系统:学习型索引的误差边界与事务隔离级别存在语义冲突
  • 多维查询(范围查询涉及非前缀列):RMI 只支持单维有序查找

五、总结

学习型索引用模型替代 B+ 树的固定搜索结构,在数据分布有规律、查询模式稳定的场景下,可以减少 2-3 层树的遍历,将点查延迟降低 40%-60%。但其核心局限在于:对数据分布的强假设、写入场景的模型重训练代价、误差边界的安全保证开销。生产落地的务实路径是混合架构——模型做粗定位、B+ 树做精查找,始终保留 B+ 树作为安全回退。学习型索引不是 B+ 树的替代品,而是在特定场景下的加速器。任何没有回退机制的学习型索引,都是数据丢失的隐患。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询