大模型推理加速:从 KV Cache 到连续批处理的工程优化全景
一、当推理延迟遇上商业现实——大模型服务的性能瓶颈链
大模型推理的性能问题不是一个单纯的"慢"字可以概括的,它是一个由多个环节串联的瓶颈链,每个环节的优化策略截然不同:
- Prefill 瓶颈:首 Token 生成前,需要将整个 Prompt 的所有 Token 通过 Transformer 前向传播计算 KV Cache。当 Prompt 长度达到 32K 时,Prefill 阶段可能占整个请求耗时的 60% 以上
- Decode 瓶颈:自回归生成阶段,每生成一个 Token 都需要读取全部 KV Cache 做一次 Attention 计算。随着生成长度增加,KV Cache 的显存占用和 Attention 计算量持续增长
- 调度瓶颈:传统逐请求串行处理模式下,短请求被长请求阻塞,GPU 利用率在 30%-50% 之间波动,大量算力被闲置
一个真实的商业案例:某 AI 写作 SaaS 产品,使用 70B 参数模型提供长文生成服务。单次请求平均耗时 45 秒,其中 Prefill 占 18 秒,Decode 占 25 秒,队列等待占 2 秒。用户可接受的等待上限是 15 秒,产品上线后用户流失率高达 40%。这不是模型能力的问题,而是推理工程的问题。技术如果不服务于真实的人性与需求,那只是一堆冰冷的代码——用户不会因为"模型参数量大"而容忍糟糕的响应速度。
二、推理加速的核心机制:从计算图到调度策略
大模型推理加速涉及三个层面的优化,以下是完整的加速架构:
graph TD subgraph "请求调度层" A[请求队列] --> B[连续批处理 Continuous Batching] B --> C[动态批大小调整] C --> D[迭代级调度 iteration-level scheduling] end subgraph "计算优化层" D --> E[PagedAttention] E --> F[KV Cache 分页管理] F --> G[显存池化复用] D --> H[FlashAttention-2] H --> I[IO 感知分块计算] end subgraph "模型加速层" D --> J[算子融合 Kernel Fusion] J --> K[量化 INT8/INT4] K --> L[投机采样 Speculative Decoding] end G --> M[推理输出] I --> M L --> M关键机制解析:
连续批处理(Continuous Batching):区别于静态批处理(等所有请求完成后才能加入新请求),连续批处理在每次 Decode 迭代后检查是否有请求完成,完成的请求立即释放槽位,新请求立即加入。这将 GPU 利用率从 30%-50% 提升到 80%-95%。
PagedAttention:借鉴操作系统虚拟内存的分页机制,将 KV Cache 按固定大小的 Page 管理。不同请求的 KV Cache 不需要连续显存,消除了显存碎片问题。当显存不足时,可以将不活跃的 KV Cache 换出到 CPU 内存(类似 Swap),而非直接拒绝请求。
投机采样(Speculative Decoding):用一个小模型(Draft Model)快速生成 K 个候选 Token,大模型(Target Model)一次前向传播验证所有候选 Token。接受正确的 Token,拒绝错误的 Token 并从拒绝点重新生成。在候选接受率 > 80% 的场景下,推理速度可提升 2-3 倍,且输出分布与原始大模型完全一致。
三、生产级连续批处理调度引擎的核心实现
以下实现聚焦于连续批处理的核心调度逻辑,包含迭代级调度、动态批大小调整和请求生命周期管理:
import time import uuid import logging from dataclasses import dataclass, field from enum import Enum from typing import Optional from collections import deque logger = logging.getLogger("continuous_batcher") class RequestStatus(Enum): """请求状态""" QUEUED = "queued" # 排队中 PREFILLING = "prefilling" # Prefill 阶段 DECODING = "decoding" # Decode 阶段 COMPLETED = "completed" # 已完成 PREEMPTED = "preempted" # 被抢占(显存不足) @dataclass class InferenceRequest: """推理请求""" request_id: str prompt_tokens: list[int] max_output_tokens: int = 512 temperature: float = 0.7 # 运行时状态 status: RequestStatus = RequestStatus.QUEUED generated_tokens: list[int] = field(default_factory=list) kv_cache_pages: list[int] = field(default_factory=list) prefill_completed: bool = False submit_time: float = field(default_factory=time.time) start_time: Optional[float] = None end_time: Optional[float] = None @property def is_finished(self) -> bool: """判断请求是否已完成生成""" return len(self.generated_tokens) >= self.max_output_tokens @property def total_tokens(self) -> int: """当前总 Token 数(prompt + generated)""" return len(self.prompt_tokens) + len(self.generated_tokens) @dataclass class GPUMemoryPool: """GPU 显存池管理""" total_pages: int # 总页数 page_size: int = 16 # 每页 Token 数 used_pages: int = 0 @property def available_pages(self) -> int: return self.total_pages - self.used_pages def allocate(self, num_pages: int) -> bool: """分配显存页,成功返回 True""" if num_pages > self.available_pages: return False self.used_pages += num_pages return True def release(self, num_pages: int) -> None: """释放显存页""" self.used_pages = max(0, self.used_pages - num_pages) def pages_needed_for_request(self, request: InferenceRequest) -> int: """计算请求所需的显存页数""" # KV Cache 需要为 prompt + max_output 分配空间 total_tokens = len(request.prompt_tokens) + request.max_output_tokens return (total_tokens + self.page_size - 1) // self.page_size class ContinuousBatcher: """ 连续批处理调度引擎 核心职责:迭代级调度、动态批大小、显存管理、请求抢占 """ def __init__( self, gpu_memory_pool: GPUMemoryPool, max_batch_size: int = 32, max_waiting_queue: int = 100, scheduling_policy: str = "fcfs", # fcfs | priority | shortest_first ): self.memory_pool = gpu_memory_pool self.max_batch_size = max_batch_size self.max_waiting_queue = max_waiting_queue self.scheduling_policy = scheduling_policy # 等待队列和活跃批次 self.waiting_queue: deque[InferenceRequest] = deque() self.active_batch: list[InferenceRequest] = [] # 统计指标 self._total_requests = 0 self._completed_requests = 0 self._preemption_count = 0 def submit_request(self, request: InferenceRequest) -> bool: """提交推理请求到等待队列""" if len(self.waiting_queue) >= self.max_waiting_queue: logger.warning("等待队列已满,拒绝请求 %s", request.request_id) return False self.waiting_queue.append(request) self._total_requests += 1 logger.info( "请求入队: %s, prompt_len=%d, 队列深度=%d", request.request_id, len(request.prompt_tokens), len(self.waiting_queue), ) return True def schedule_iteration(self) -> list[InferenceRequest]: """ 执行一次迭代级调度 返回当前迭代需要处理的请求列表 核心逻辑:完成请求出批 → 抢占低优先级请求 → 新请求入批 """ # 第一步:移除已完成的请求,释放显存 completed = [] remaining = [] for req in self.active_batch: if req.is_finished or req.status == RequestStatus.COMPLETED: req.status = RequestStatus.COMPLETED req.end_time = time.time() completed.append(req) # 释放 KV Cache 显存 self.memory_pool.release(len(req.kv_cache_pages)) self._completed_requests += 1 else: remaining.append(req) self.active_batch = remaining if completed: logger.info( "本轮完成 %d 个请求,释放显存页,当前活跃 %d", len(completed), len(self.active_batch), ) # 第二步:检查显存,必要时抢占低优先级请求 self._check_and_preempt() # 第三步:从等待队列中补充新请求到活跃批次 self._admit_new_requests() # 更新活跃请求状态 for req in self.active_batch: if not req.prefill_completed: req.status = RequestStatus.PREFILLING else: req.status = RequestStatus.DECODING return self.active_batch def _check_and_preempt(self) -> None: """检查显存压力,必要时抢占请求""" # 估算当前活跃批次的显存需求 current_need = sum( len(req.kv_cache_pages) for req in self.active_batch ) # 如果显存紧张(可用页 < 总页的 10%),抢占最长请求 if self.memory_pool.available_pages < self.memory_pool.total_pages * 0.1: # 按生成 Token 数降序排列,抢占占用最多资源的请求 sorted_requests = sorted( self.active_batch, key=lambda r: len(r.generated_tokens), reverse=True, ) for req in sorted_requests: if self.memory_pool.available_pages >= self.memory_pool.total_pages * 0.2: break # 抢占:释放显存,请求回到等待队列 self.memory_pool.release(len(req.kv_cache_pages)) req.kv_cache_pages = [] req.status = RequestStatus.PREEMPTED req.prefill_completed = False # 需要重新 Prefill self.active_batch.remove(req) self.waiting_queue.appendleft(req) # 优先重新调度 self._preemption_count += 1 logger.warning( "抢占请求 %s,释放 %d 页显存", req.request_id, len(req.kv_cache_pages) or 0, ) def _admit_new_requests(self) -> None: """从等待队列中接纳新请求到活跃批次""" available_slots = self.max_batch_size - len(self.active_batch) if available_slots <= 0: return # 根据调度策略排序等待队列 if self.scheduling_policy == "shortest_first": # 短任务优先:减少长请求对短请求的阻塞 candidates = sorted( list(self.waiting_queue), key=lambda r: r.max_output_tokens, ) else: candidates = list(self.waiting_queue) admitted = [] remaining = [] for req in candidates: if len(admitted) >= available_slots: remaining.append(req) continue # 检查显存是否足够 pages_needed = self.memory_pool.pages_needed_for_request(req) if self.memory_pool.allocate(pages_needed): req.kv_cache_pages = list(range(pages_needed)) # 简化:分配页 ID req.start_time = time.time() admitted.append(req) logger.debug( "接纳请求 %s,分配 %d 页显存", req.request_id, pages_needed, ) else: # 显存不足,放回队列 remaining.append(req) self.active_batch.extend(admitted) # 更新等待队列(保持 FCFS 顺序) admitted_ids = {r.request_id for r in admitted} self.waiting_queue = deque( r for r in self.waiting_queue if r.request_id not in admitted_ids ) def get_metrics(self) -> dict: """获取调度器运行指标""" return { "total_requests": self._total_requests, "completed_requests": self._completed_requests, "active_batch_size": len(self.active_batch), "waiting_queue_size": len(self.waiting_queue), "preemption_count": self._preemption_count, "gpu_utilization": self.memory_pool.used_pages / self.memory_pool.total_pages, }关键工程决策说明:
- 迭代级调度:
schedule_iteration在每次 Decode 迭代后调用,完成的请求立即出批、新请求立即入批,避免静态批处理中的槽位浪费 - 抢占机制:当显存使用率超过 90% 时,抢占生成 Token 最多的请求(占用资源最多),将其 KV Cache 释放并放回等待队列前端优先重新调度。这是 vLLM 等推理框架的核心策略
- 短任务优先策略:
shortest_first调度策略减少长请求对短请求的阻塞,降低尾部延迟。代价是长请求的等待时间增加,需要根据业务场景权衡 - 显存分页管理:
GPUMemoryPool模拟 PagedAttention 的显存管理逻辑,按页分配和释放,消除显存碎片
四、推理加速方案的适用边界与架构权衡
适用场景:
- 在线推理服务:需要同时处理多个并发请求,GPU 利用率是核心指标
- 长文本生成场景:KV Cache 管理和 PagedAttention 对长上下文场景收益最大
- 成本敏感的 SaaS 部署:连续批处理 + 量化 + 投机采样的组合拳可降低 50% 以上的推理成本
不适用场景:
- 单请求低延迟场景:连续批处理的调度开销可能增加单请求的尾部延迟
- 离线批量推理:不需要并发调度,直接静态批处理更简单高效
- 极小模型(< 1B 参数):推理本身已足够快,优化收益有限
架构妥协:
- 吞吐 vs 延迟:连续批处理优化的是吞吐量,但批大小增加会提高单请求延迟。需要根据 SLA 要求设置
max_batch_size上限 - 显存 vs 容量:PagedAttention 允许超卖显存(更多并发请求),但 Swap 到 CPU 内存时性能会骤降。生产环境需要设置显存水位线并监控 Swap 频率
- 精度 vs 速度:INT4/INT8 量化可带来 2-4 倍加速,但在代码生成、数学推理等精度敏感场景下,量化可能导致输出质量下降。建议 A/B 测试验证后再上线
- 投机采样的接受率依赖:Speculative Decoding 的加速比取决于 Draft Model 的候选接受率。当两个模型分布差异较大时(如不同架构),接受率可能低于 50%,反而因额外的验证开销降低速度
五、总结
大模型推理加速是一个多层协同的工程优化问题:调度层通过连续批处理和迭代级调度提升 GPU 利用率,计算层通过 PagedAttention 和 FlashAttention 优化显存和计算效率,模型层通过量化和投机采样降低单次推理开销。连续批处理的核心思想是在每次 Decode 迭代后动态调整批次组成,消除请求间的相互阻塞。PagedAttention 借鉴虚拟内存分页机制解决 KV Cache 的显存碎片问题。各优化方案在吞吐、延迟、精度之间存在固有权衡,需要根据具体业务场景的 SLA 要求和成本约束进行组合选型。