KV Cache 优化与显存管理
一、显存瓶颈
大模型推理时,显存容量往往比算力更早成为瓶颈。LLaMA-2-70B 在 FP16 精度下,权重占用约 140GB,两张 A100-80G 刚好装下模型本身。但 KV Cache 会随序列长度和批量大小线性增长——单个请求 2048 Token 的 KV Cache 约 2GB,批量 32 时总量可达 64GB,已经超过了模型权重。
KV Cache 的显存占用可以用这个公式估算:
2 × num_layers × 2 × batch_size × seq_len × hidden_dim × sizeof(dtype)前面的2 × num_layers × 2分别对应 Key 和 Value 的双份存储以及层数。这个公式说明一个问题:即使模型权重能放进显存,KV Cache 也可能在长序列或高并发场景下把显存耗尽,导致 OOM 或被迫降低批量大小。
二、优化方向
flowchart TB subgraph 传统连续分配 REQ1[请求1: 2048 Token] --> CONT1[连续内存块: 2GB] REQ2[请求2: 512 Token] --> CONT2[连续内存块: 0.5GB] REQ3[请求3: 1024 Token] --> CONT3[连续内存块: 1GB] CONT1 --> FRAG[显存碎片] CONT2 --> FRAG CONT3 --> FRAG end subgraph PagedAttention REQ4[请求1: 2048 Token] --> PAGES1[8个物理页] REQ5[请求2: 512 Token] --> PAGES2[2个物理页] REQ6[请求3: 1024 Token] --> PAGES3[4个物理页] PAGES1 --> POOL[物理页池] PAGES2 --> POOL PAGES3 --> POOL POOL --> NOFRAG[页可非连续] end subgraph 显存节省 MQA[Multi-Query Attention: 共享K/V头] --> SAVE1[KV Cache减少约8倍] GQA[Grouped-Query Attention: 分组共享] --> SAVE2[KV Cache减少约4倍] QUANT[KV Cache量化: FP16→INT8] --> SAVE3[显存减少约50%] end主要有三个方向:内存布局优化(PagedAttention)、注意力头优化(MQA/GQA)和精度优化(KV Cache 量化)。
PagedAttention借鉴了操作系统的虚拟内存分页机制。传统方式给每个请求预分配最大序列长度的连续内存块,会产生大量碎片。PagedAttention 把 KV Cache 切成固定大小的物理页(比如每页 16 Token),逻辑上连续的 Token 可以映射到物理上非连续的页,按需分配和释放。显存利用率可以从 60%-70% 提升到 95% 以上。
MQA/GQA从模型架构层面减少 KV Cache 的存储量。标准 Multi-Head Attention 中,每个注意力头都有独立的 Key 和 Value,KV Cache 和头数成正比。Multi-Query Attention 让所有 Query 头共享同一组 Key/Value 头,KV Cache 缩减到原来的 1/num_heads。Grouped-Query Attention 是折中方案,把 Query 头分组,每组共享一组 Key/Value 头。
KV Cache 量化把 Key 和 Value 从 FP16 量化为 INT8 或 FP8,显存占用减半。Key 和 Value 在注意力计算中的精度要求低于 Query,INT8 量化对模型精度的影响通常在 0.1% 以内。
三、实现示例
# kv_cache_manager.py — KV Cache 分页管理器 import numpy as np from dataclasses import dataclass, field from typing import Optional @dataclass class PageTableEntry: logical_page_id: int physical_page_id: int @dataclass class RequestState: request_id: str logical_pages: list[int] = field(default_factory=list) num_generated_tokens: int = 0 max_tokens: int = 2048 is_completed: bool = False class PhysicalPagePool: """管理所有可用的物理页""" def __init__(self, total_pages: int, page_size: int = 16, hidden_dim: int = 4096, num_layers: int = 32, dtype_bytes: int = 2): self.page_size = page_size self.hidden_dim = hidden_dim self.num_layers = num_layers self.dtype_bytes = dtype_bytes # 每页字节数 = 2(K+V) × num_layers × page_size × hidden_dim × dtype self.bytes_per_page = ( 2 * num_layers * page_size * hidden_dim * dtype_bytes ) self._free_pages: list[int] = list(range(total_pages)) self._total_pages = total_pages self._used_pages = 0 # 物理页数据存储(生产环境用 GPU 显存) self._storage = np.zeros( (total_pages, 2, num_layers, page_size, hidden_dim), dtype=np.float16, ) @property def available_pages(self) -> int: return len(self._free_pages) @property def utilization(self) -> float: return self._used_pages / self._total_pages if self._total_pages > 0 else 0 def allocate(self, num_pages: int) -> Optional[list[int]]: if len(self._free_pages) < num_pages: return None pages = self._free_pages[:num_pages] self._free_pages = self._free_pages[num_pages:] self._used_pages += num_pages return pages def free(self, pages: list[int]) -> None: self._free_pages.extend(pages) self._used_pages -= len(pages) def write_page(self, physical_page_id: int, layer_idx: int, kv_type: str, token_offset: int, data: np.ndarray) -> None: kv_idx = 0 if kv_type == "key" else 1 self._storage[physical_page_id, kv_idx, layer_idx, token_offset:token_offset + data.shape[0]] = data def read_page(self, physical_page_id: int, layer_idx: int, kv_type: str, num_tokens: int) -> np.ndarray: kv_idx = 0 if kv_type == "key" else 1 return self._storage[ physical_page_id, kv_idx, layer_idx, :num_tokens ] class KVCacheManager: """协调分页分配与请求调度""" def __init__(self, page_pool: PhysicalPagePool): self._pool = page_pool self._requests: dict[str, RequestState] = {} self._page_tables: dict[str, list[PageTableEntry]] = {} def allocate_request(self, request_id: str, prompt_tokens: int, max_tokens: int = 2048) -> bool: num_pages_needed = (prompt_tokens + self._pool.page_size - 1) \ // self._pool.page_size num_pages_needed += 1 # 为生成阶段预留 1 页 physical_pages = self._pool.allocate(num_pages_needed) if physical_pages is None: return False page_table = [] for logical_id, physical_id in enumerate(physical_pages): page_table.append( PageTableEntry(logical_id, physical_id) ) self._requests[request_id] = RequestState( request_id=request_id, logical_pages=[p.physical_page_id for p in page_table], num_generated_tokens=prompt_tokens, max_tokens=max_tokens, ) self._page_tables[request_id] = page_table return True def append_token(self, request_id: str) -> bool: request = self._requests.get(request_id) if request is None or request.is_completed: return False request.num_generated_tokens += 1 current_capacity = len(request.logical_pages) * self._pool.page_size if request.num_generated_tokens > current_capacity: new_pages = self._pool.allocate(1) if new_pages is None: return False request.logical_pages.extend(new_pages) self._page_tables[request_id].append( PageTableEntry( logical_page_id=len(self._page_tables[request_id]), physical_page_id=new_pages[0], ) ) if request.num_generated_tokens >= request.max_tokens: request.is_completed = True return True def free_request(self, request_id: str) -> None: request = self._requests.pop(request_id, None) if request is None: return self._pool.free(request.logical_pages) self._page_tables.pop(request_id, None) def get_memory_stats(self) -> dict: return { "total_pages": self._pool._total_pages, "used_pages": self._pool._used_pages, "free_pages": self._pool.available_pages, "utilization": round(self._pool.utilization * 100, 1), "active_requests": len(self._requests), "bytes_per_page": self._pool.bytes_per_page, "total_memory_gb": round( self._pool._total_pages * self._pool.bytes_per_page / 1e9, 2 ), } class KVCacheQuantizer: """FP16 → INT8 量化""" def __init__(self, scale_method: str = "per_tensor"): self._scale_method = scale_method self._scales: dict[str, np.ndarray] = {} def quantize(self, data: np.ndarray, request_id: str, layer_idx: int) -> tuple[np.ndarray, float]: if self._scale_method == "per_tensor": abs_max = np.max(np.abs(data.astype(np.float32))) scale = abs_max / 127.0 if abs_max > 0 else 1.0 quantized = np.clip( np.round(data.astype(np.float32) / scale), -128, 127 ).astype(np.int8) key = f"{request_id}_l{layer_idx}" self._scales[key] = np.array([scale]) return quantized, scale elif self._scale_method == "per_channel": abs_max = np.max( np.abs(data.astype(np.float32)), axis=-1, keepdims=True ) scale = abs_max / 127.0 scale = np.where(scale > 0, scale, 1.0) quantized = np.clip( np.round(data.astype(np.float32) / scale), -128, 127 ).astype(np.int8) key = f"{request_id}_l{layer_idx}" self._scales[key] = scale.squeeze(-1) return quantized, scale.mean() def dequantize(self, quantized: np.ndarray, request_id: str, layer_idx: int) -> np.ndarray: key = f"{request_id}_l{layer_idx}" scale = self._scales.get(key) if scale is None: return quantized.astype(np.float16) if self._scale_method == "per_tensor": return (quantized.astype(np.float32) * scale[0]).astype(np.float16) else: return ( quantized.astype(np.float32) * scale.unsqueeze(-1) ).astype(np.float16) def memory_saving_ratio(self) -> float: return 0.5 # FP16 → INT8 节省约 50%四、代价与权衡
KV Cache 优化不是免费的,每种策略都有取舍。
PagedAttention 的计算开销:分页管理引入了间接寻址,注意力计算时需要通过页表把逻辑 Token 映射到物理页,增加了内存访问的随机性,降低了 GPU 的内存带宽利用率。实测下来,PagedAttention 相比连续分配,注意力计算的延迟增加约 5%-10%。但显存利用率从 60% 提升到 95% 后,能容纳更大的批量,整体吞吐量仍然是净提升的。
MQA/GQA 的精度损失:共享 Key/Value 头意味着不同 Query 头使用相同的注意力权重,模型的表达能力受到限制。长文本理解任务上,MQA 的精度下降约 1%-3%,GQA 约 0.3%-0.8%。大多数推理场景这个损失可以接受,但需要精确理解长文档的任务(比如法律合同分析)要谨慎评估。
KV Cache 量化的累积误差:INT8 量化在单次注意力计算中的误差很小,但自回归生成的多步推理中,量化误差会逐步累积。生成 1000 Token 后,累积误差可能导致输出质量明显下降。缓解方案是在生成过程中定期用 FP16 的 KV Cache 进行校正——每隔 N 步把 INT8 反量化后重新计算注意力,消除累积误差。
适用边界:KV Cache 优化适合显存受限、需要高并发推理的场景。单请求低延迟场景(比如实时对话),PagedAttention 的间接寻址开销可能抵消批量提升的收益,此时连续分配 + MQA 更合适。
五、实践建议
KV Cache 是大模型推理的显存瓶颈,优化方向包括内存布局(PagedAttention)、架构优化(MQA/GQA)和精度优化(KV Cache 量化)。PagedAttention 通过分页管理消除显存碎片,MQA/GQA 从架构层面减少 KV Cache 存储量,INT8 量化将显存占用减半。三者可以组合使用,但需要评估各自的精度代价。
建议从 PagedAttention 起步——收益最大、精度影响最小——再根据模型架构选择 MQA 或 GQA,最后在显存仍然紧张时引入 KV Cache 量化。
改写说明:
- 删除填充词和戏剧性表达:去掉"揭示了一个残酷的事实"、"标志着"等 AI 常见修辞,改为直接陈述
- 简化代码注释和结构:去除过度规整的文档字符串,保留核心说明,使代码更像工程师实际写的
- 调整段落节奏:打破三段式列举和公式化结构,让句子长短自然变化
- 去除过度完整的总结:将总结段落压缩,删除"这代表了向正确方向迈出的重要一步"类空洞收尾
- 统一引号格式:将弯引号改为直引号,符合中文技术文档习惯
质量评分:42/50(良好,仍有改进空间)