PyTorch 分布式推理:从单卡到多卡服务,模型部署的性能优化与架构设计
一、模型推理的工程困境:单卡吞吐的天花板
大模型推理面临的核心矛盾是:模型参数量持续增长(7B → 70B → 405B),单卡显存与算力增长缓慢。7B 模型在 FP16 下需要 14GB 显存,单卡 A100 可容纳;70B 模型需要 140GB,单卡无法装载。即使模型能放入单卡,推理吞吐量也受限于单卡算力——A100 的 FP16 算力约 312 TFLOPS,而 7B 模型生成一个 Token 需要约 14 GFLOPS,理论极限约 22 Token/s,实际受内存带宽限制仅 10-15 Token/s。
分布式推理是突破单卡瓶颈的必经之路:Tensor Parallelism(张量并行)将模型层内切分到多卡并行计算,Pipeline Parallelism(流水线并行)将模型层间切分到多卡流水执行。理解分布式推理的架构设计与性能优化,是部署大规模模型服务的工程基础。
二、分布式推理的并行策略与架构
flowchart TD A[大模型推理需求] --> B[单卡瓶颈: 显存/算力不足] B --> C[分布式推理] subgraph 张量并行 TP D[权重按列切分: Column Parallel] E[权重按行切分: Row Parallel] F[注意力头切分: Multi-Head TP] end subgraph 流水线并行 PP G[层间切分: Stage划分] H[微批次流水: Micro-batch] I[气泡优化: 调度策略] end subgraph 数据并行 DP J[同模型多副本] K[请求级负载均衡] L[显存冗余: 每卡完整模型] end C --> D C --> G C --> J subgraph 推理优化技术 M[KV Cache: 避免重复计算] N[Continuous Batching: 动态组批] O[PagedAttention: 显存分页管理] P[Speculative Decoding: 推测解码] end D --> M D --> N D --> O D --> P subgraph 工程选型 Q[7B: 单卡 + Continuous Batching] R[70B: TP=4 + KV Cache + PagedAttention] S[405B: TP=8 + PP=2 + 全部优化] end M --> Q N --> R O --> S张量并行的核心思想:将线性层的权重矩阵按列或行切分到多卡,每卡计算部分结果后通过 All-Reduce 通信聚合。对于 Transformer 的 MLP 层,第一个线性层按列切分(每卡计算不同列的输出),第二个线性层按行切分(每卡计算部分行的结果后求和)。两次 All-Reduce 通信的延迟是 TP 的主要开销。
三、工程实现:张量并行推理引擎
# distributed_inference.py — PyTorch 分布式推理引擎 import torch import torch.nn as nn import torch.distributed as dist from typing import Optional, Tuple def init_distributed(backend: str = "nccl"): """初始化分布式环境""" dist.init_process_group(backend=backend) torch.cuda.set_device(dist.get_rank()) class ColumnParallelLinear(nn.Module): """列并行线性层:权重按列切分到多卡 将权重 W (out_features, in_features) 按列切分为 W_1, W_2, ..., W_n,每卡持有 W_i (out_features/n, in_features)。 前向时每卡独立计算 y_i = x @ W_i^T,结果为完整输出的子集。 """ def __init__( self, in_features: int, out_features: int, world_size: int, rank: int, bias: bool = True, ): super().__init__() self.world_size = world_size self.rank = rank self.out_features_per_partition = out_features // world_size # 每卡仅持有 1/world_size 的权重 self.weight = nn.Parameter( torch.empty(self.out_features_per_partition, in_features) ) if bias: self.bias = nn.Parameter( torch.empty(self.out_features_per_partition) ) else: self.register_parameter('bias', None) self._init_weights() def _init_weights(self): nn.init.kaiming_uniform_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, in_features) → output: (B, out_features/world_size) output = F.linear(x, self.weight, self.bias) return output class RowParallelLinear(nn.Module): """行并行线性层:权重按行切分到多卡 将权重 W (out_features, in_features) 按行切分为 W_1, W_2, ..., W_n,每卡持有 W_i (out_features, in_features/n)。 前向时每卡计算 y_i = x_i @ W_i^T,All-Reduce 求和得到完整输出。 """ def __init__( self, in_features: int, out_features: int, world_size: int, rank: int, bias: bool = True, ): super().__init__() self.world_size = world_size self.rank = rank self.in_features_per_partition = in_features // world_size self.weight = nn.Parameter( torch.empty(out_features, self.in_features_per_partition) ) if bias: # 行并行的 bias 只在 rank 0 上持有,避免重复 self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter('bias', None) self._init_weights() def _init_weights(self): nn.init.kaiming_uniform_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, in_features/world_size) → partial: (B, out_features) partial_output = F.linear(x, self.weight) # All-Reduce 求和:聚合所有卡的部分结果 dist.all_reduce(partial_output, op=dist.ReduceOp.SUM) if self.bias is not None and self.rank == 0: partial_output += self.bias # 广播 bias 加法结果到所有卡 if self.bias is not None: dist.broadcast(partial_output, src=0) return partial_output class TPMLP(nn.Module): """张量并行的 MLP 层:Column Parallel + Row Parallel 标准 Transformer MLP 的并行化: 1. 第一个线性层按列切分(每卡计算不同中间维度) 2. 激活函数各卡独立计算 3. 第二个线性层按行切分(All-Reduce 聚合结果) 通信开销:仅一次 All-Reduce(在 Row Parallel 层后) """ def __init__( self, hidden_dim: int, ff_dim: int, world_size: int, rank: int, ): super().__init__() self.up_proj = ColumnParallelLinear( hidden_dim, ff_dim, world_size, rank ) self.gate_proj = ColumnParallelLinear( hidden_dim, ff_dim, world_size, rank, bias=False ) self.down_proj = RowParallelLinear( ff_dim, hidden_dim, world_size, rank ) def forward(self, x: torch.Tensor) -> torch.Tensor: # SwiGLU 激活:gate * up gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) intermediate = gate * up output = self.down_proj(intermediate) return output class KVCache: """KV Cache 管理:避免自回归解码时重复计算已生成的 Key/Value 核心思想:生成第 t 个 Token 时,前 t-1 个 Token 的 KV 已在 之前步骤中计算并缓存,只需计算第 t 个 Token 的 KV 并追加。 """ def __init__( self, num_layers: int, num_heads: int, head_dim: int, max_batch_size: int = 32, max_seq_len: int = 2048, dtype: torch.dtype = torch.float16, ): self.num_layers = num_layers self.dtype = dtype # 预分配 KV Cache 显存(避免动态分配的碎片化) cache_shape = ( max_batch_size, num_heads, max_seq_len, head_dim ) self.key_cache = [ torch.zeros(cache_shape, dtype=dtype, device='cuda') for _ in range(num_layers) ] self.value_cache = [ torch.zeros(cache_shape, dtype=dtype, device='cuda') for _ in range(num_layers) ] self.current_seq_len = 0 def update( self, layer_idx: int, new_key: torch.Tensor, new_value: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """更新 Cache 并返回完整的 Key/Value""" batch_size = new_key.size(0) # 将新的 KV 追加到 Cache self.key_cache[layer_idx][:batch_size, :, self.current_seq_len] = new_key.squeeze(2) self.value_cache[layer_idx][:batch_size, :, self.current_seq_len] = new_value.squeeze(2) # 返回从 0 到 current_seq_len+1 的完整 KV full_key = self.key_cache[layer_idx][:batch_size, :, :self.current_seq_len + 1] full_value = self.value_cache[layer_idx][:batch_size, :, :self.current_seq_len + 1] if layer_idx == self.num_layers - 1: self.current_seq_len += 1 return full_key, full_value def reset(self) -> None: """重置 Cache(新请求开始时调用)""" self.current_seq_len = 0四、分布式推理的边界与权衡
通信开销与并行度的矛盾:张量并行的每次 All-Reduce 通信延迟约 50-100μs(NVLink),8 卡 TP 需要两次 All-Reduce,通信开销约 100-200μs。当单次推理计算时间小于通信开销时(小 Batch、短序列),TP 反而降低吞吐。建议 TP 并行度不超过 8(单节点 NVLink 带宽内),跨节点使用 PP 而非 TP。
KV Cache 的显存占用:KV Cache 的显存占用与序列长度线性增长。70B 模型在 FP16 下,单请求 2048 Token 的 KV Cache 约需 5GB 显存。32 个并发请求的 KV Cache 占用 160GB,超过模型权重本身。PagedAttention(vLLM 的核心技术)将 KV Cache 分页管理,显著减少显存碎片,提升并发能力 2-4 倍。
Continuous Batching 的调度复杂度:传统 Static Batching 等待所有请求生成完毕才处理下一批,短请求被长请求拖慢。Continuous Batching 在每个 Token 生成步骤动态调度:已完成的请求立即移出,新请求立即加入。调度逻辑复杂度显著增加,但吞吐量提升 2-3 倍。
量化与并行的交互:INT8/INT4 量化将模型显存占用减半,但量化后的矩阵乘法在 TP 模式下需要额外的反量化通信。建议在 TP 之前完成量化,确保每卡持有量化后的权重,减少通信数据量。
五、总结
分布式推理是大规模模型部署的工程基础。张量并行在单节点内切分模型层内计算,流水线并行跨节点切分模型层间计算,数据并行通过多副本提升吞吐。工程落地的关键在于:TP 并行度控制在 NVLink 带宽内(不超过 8)、KV Cache 分页管理突破显存瓶颈、Continuous Batching 动态调度提升吞吐、量化与并行协同优化减少通信开销。分布式推理不是简单的"多卡堆叠",而是通信与计算的精细平衡——理解这一权衡,才能设计出高吞吐、低延迟的模型服务架构。