x00 概要
CMU 贾志豪老师团队提出的MPK(Mirage Persistent Kernel)是依托 Mirage 编译器生态的创新运行时系统,其核心能力在于将多GPU环境下大语言模型(LLM)推理任务自动转换为适配GPU架构的高性能巨型内核(megakernel)。MPK的关键优势在于将传统由CPU负责的内核调度和任务依赖管理工作转移到GPU端,通过“长期驻留的巨型内核(Persistent Kernel)”自主完成,同时统筹GPU内部计算与跨GPU通信任务。这种设计不仅大幅削减了CPU-GPU交互带来的内核启动开销,还通过计算与通信的细粒度重叠,将推理延迟优化至接近硬件物理极限,显著提升推理效率。
0.1 传统LLM推理框架的瓶颈
传统LLM推理框架的流程存在固有瓶颈:CPU需逐个发起CUDA内核调用(如矩阵乘法、激活函数计算),待GPU执行完当前内核并反馈后,再触发下一个内核。这种“CPU发起-GPU执行-CPU等待”的循环,会产生频繁的CPU-GPU通信与内核启动开销,尤其在自回归生成场景中,单次token生成需多轮内核调用,开销会持续累积,严重拖累整体推理性能。
0.2 MPK的流程重构
MPK彻底重构了这一流程:它仅需CPU在推理初始化阶段,向GPU提交一个“永不主动退出的persistent_kernel”,之后所有任务分派(如层间计算顺序)、依赖管理(如等待前一层结果再执行下一层)均由GPU内部自主完成。此时CPU的角色从“实时调度的包工头”转变为“启动初始化的门卫”,仅负责触发首次内核启动,后续不再参与任何具体调度。
0.3 MPK的关键优势
MPK通过将多GPU的LLM推理任务转换为高性能的巨型内核,从根本上改变了GPU的运行模式。它不仅减少了内核启动开销,还通过细粒度的软件流水线和计算通信重叠,显著提高了推理效率。MPK提供了一种全新的思路,将性能优化的重心从“如何调用优化库”转移到了“如何为整个模型生成一个最优的、原生的执行体”,在多GPU环境下实现了更高的吞吐量和更低的延迟。
0x01 问题
重新设计类似 Mirage 的 MegaKernel的优势,是将所有计算和通信融合进一个单一的巨型内核(也称为持续内核)是降低大语言模型推理延迟的最有效方法之一。这种方法通过启动一个GPU内核来执行整个模型,从逐层计算到GPU间通信,整个过程无需中断。尽管有这些优势,将LLM编译成巨型内核仍然极具挑战性。
1.1 现有框架问题
现有框架难以支持单一的巨型内核。
现有的高级ML框架,如PyTorch、Triton和TVM,并不原生支持端到端巨型内核生成。
现代LLM系统由各种不同的专用内核库构建而成,这种碎片化使得将整个推理流水线整合进一个单一的、统一的内核变得非常困难。
高性能GPU内核的手工编写需要大量的专家知识,如何自动生成高性能内核代码是一个痛点问题。传统做法依赖于专家编写好的内核或者手工融合规则,但这些方法维护成本高,容易漏掉跨内核/层级组合优化的机会。
1.2 编程抽象层级
从编程抽象层级上来看,也缺乏最优系统。
1.2.1 GPU架构
下图展示了当今 GPU 的层次结构。GPU 上的计算被组织为内核,每个内核都是一个函数,以单程序多数据(SPMD)的方式在多个 GPU 核心上同时执行。一个内核包括一个线程块网格,每个线程块在一个 GPU 流式多处理器上执行,并包括多个线程来对单个数据元素进行计算。每个线程都与一个每线程寄存器文件相关联,并且线程块内的所有线程都可以访问共享内存以启用集体操作。最后,内核的所有输入和输出都存储在 GPU 设备内存中。
下图是GPU hierarchy。
gpu_hierarchy
下图为GPU 计算架构和编程抽象示意图
image
1.2.2 编程视角
Triton 是一款高级 GPU 编程框架,其编程视角主要聚焦于块(Block)级别。该框架的设计允许开发者以块为单位进行编程,而块内部的优化工作则由 Triton 编译器自动完成。这种设计模式使开发者能够将精力集中在高层逻辑的构建上,无需深入研究线程(Thread)级别的细节实现。Triton 的核心优势在于其简洁的编程模型和自动化优化能力,这使得它在处理复杂并行任务时具有更高的效率。
Cutlass 则属于底层 GPU 编程库,其编程视角覆盖了块(Block)、线程束(Warp)与线程(Thread)的完整层级。Cutlass 提供了丰富的 CUDA 模板和底层控制接口,开发者可以利用这些工具精细调控每个线程的行为,从而实现高度优化的计算内核。这种细粒度的控制能力让 Cutlass 在对性能有极致要求的场景中表现出色,但同时也增加了编程的复杂性。
正是这种编程视角的层级差异,构成了当前高性能 GPU 编程领域的核心挑战:缺乏一套能够 “跨内核(Kernel)、线程块(Block)、线程(Thread)三个层级” 联合搜索最优计算方案,并自动验证方案正确性的系统。现有框架要么局限于单一层级的优化(例如 Triton 仅针对块内部逻辑进行优化,而 Cutlass 则需要开发者手动协调全层级的适配),要么无法在多层级协同后确保计算结果的准确性。这一问题在大型语言模型(LLM)推理等复杂张量计算场景中,会显著增加开发成本与优化难度。
0x02 总体思路
MegaKernel 可被视为一种 grid 级(网格级)的内核抽象。与 CUDA 的 thread 级(线程级)抽象、Triton 的 block 级(块级)抽象不同,它提供了层次更高的抽象能力,允许开发者在 grid 级开展编程工作。这种抽象设计能让开发者更灵活地管理 GPU 上的计算资源,进而实现更高效的内核生成与执行。
MPK 的工作原理主要包含以下两部分:
MPK 编译器:负责将大语言模型(LLM)的计算图转换为经过优化的任务图。
MPK 运行时系统:在单个巨型内核内部执行任务图,以此达成高吞吐量与低延迟的目标。
2.1 编译过程
模型翻译:将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。
任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。
2.2 执行过程
任务调度:将生成的任务图交付调度器执行。调度器负责管理 GPU 中的流式多处理器(SM),并通过 warp specialization(线程束特化)技术将 SM 划分为 worker(工作单元)与 scheduler(调度单元)。这种设计与数据处理领域的 actor 模型(角色模型)相似:scheduler 负责协调任务的执行顺序,worker 则负责具体执行分配到的任务。
性能优化:在小模型与低 Batch(批次)场景下,MPK 通过多种方式显著降低延迟,具体包括:消除内核启动开销、打破内核边界限制、实现细粒度的 SM 调度,以及对任务特定模式进行融合。
0x03 通过代码来打通流程
我们以demo_chat.py为例来进行全局打通。在此文件中会将Python模型结构映射为Mirage的计算图表示,然后编译为高效的持久化CUDA内核执行。
3.1 核心模块说明
3.1.1 三层结构化图模型
Mirage 实现了多层次计算图表示(μGraphs),通过 kernel-graph、block-graph 和 thread-graph 这三层结构化图模型,精确映射了 GPU 程序从内核到线程的执行逻辑与存储层级。这种三层结构与 CUDA 程序的执行层级及 GPU 的存储体系紧密对应,每层都清晰定义了“算子类型 - 张量存储 - 核心功能”的关联。
三层图功能如下:
Kernel Graph 是最高计算图,定义整个执行流程。通过自定义操作管理多个block graph
Block Graph 是嵌套在自定义操作中,定义线程块执行序列
Thread Graph是最低层,定义线程级别执行细节
3.1.2 PersistentKernel
PersistentKernel 作为计算图的容器和执行器,提供了从计算图构建、优化到执行的过程。
persistent_kernel.py是 PersistentKernel的Python接口,本质是Python到CUDA持久化内核系统的桥梁,允许用户用python定义复杂的计算图,然后在GPU上高效执行。
3.1.3 层级关系
计算图与 PersistentKernel 的关系如下:
包含关系:PersistentKernel 内部包含并管理一个 Kernel graph
构建关系:通过 PersistentKernel 的各种layer方法构建计算图。
转换关系:PersistentKernel 将计算图转换为可执行的任务图
执行关系:PersistentKernel 是计算图的执行引擎。
3.1.4 数据流关系
数据流关系可以近似如下图所示:
应用层:PersistentKernel.py(创建并管理kernel graph)
│
│
▼
输入张量
│
│
▼
计算图节点(各种layer方法添加)
│
│
▼
任务层:kernel graph(包括所有操作和计算流,即定义张量数据流)
│
│
▼
并行层:block graph(嵌套在自定义操作中,定义线程块执行序列,即定义内存访问模式)
│
│
▼
执行层:task graph(kernel graph生成的可执行任务图,taskDesc是可执行任务,EventDesc管理事件同步和依赖)
│
│
▼
运行时环境:PersistentKernel 执行引擎
│
│
▼
硬件层:Thread graph,在实际GPU线程中执行具体操作
3.2 main()代码
demo_chat.py的main()如下。
def main():
world_size, rank = setup_distributed_environment()
model, tokenizer = load_model_and_tokenizer(rank)
tokens = torch.full((1, MAX_SEQ_LEN), 0, dtype=torch.long, device="cuda")
step_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
mpk = None
if args.use_mirage:
# 构建计算图
mpk = build_mirage_graph(model, world_size, rank, args, tokens, step_tensor)
positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
position_embeddings = model.model.rotary_emb(positions)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
while True:
prompt_container = [None]
if rank == 0:
try:
prompt = input("> User: ")
prompt_container[0] = prompt
except EOFError:
prompt_container[0] = "exit"
if world_size > 1:
dist.broadcast_object_list(prompt_container, src=0)
prompt = prompt_container[0]
messages.append({"role": "user", "content": prompt})
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
new_prompt_len = model_inputs.input_ids.shape[-1]
tokens[0, :new_prompt_len] = model_inputs.input_ids[0]
if new_prompt_len < tokens.shape[1]:
tokens[0, new_prompt_len:] = 0
prompt_len = new_prompt_len
if args.use_mirage:
end_pos, run_time, generated_len = run_mirage_generation(
model, mpk, tokens, prompt_len, step_tensor, position_embeddings
)
else:
end_pos, run_time, generated_len = run_pytorch_generation(
model, tokens, prompt_len, step_tensor, position_embeddings
)
if rank == 0:
assistant_response_ids = tokens[0, prompt_len:end_pos]
assistant_response = tokenizer.decode(assistant_response_ids, skip_special_tokens=True)
if world_size > 1:
dist.destroy_process_group()
print("Exiting demo.")
总体过程如下:
模型定义阶段
使用PyTorch/HuggingFace定义模型结构。
加载预训练权重
初始化输入张量和相关参数。
任务图构建阶段
通过KNOperator定义计算操作
构建完整的计算图结构
设置任务配置参数
任务图优化阶段
分析任务间的依赖关系
生成事件描述以管理依赖
对任务进行合理分组以优化执行
任务图转换阶段
生成TaskDesc描述每个计算任务
生成EventDesc描述任务间同步事件
生成CUDA可执行代码
输出JSON配置文件用于运行时加载。
运行时初始化阶段
配置GPU资源(worker,调度器等)
分配GPU内存给任务队列和事件队列
初始化工作队列和调度队列。
设置事件计数器和相关同步机制
持久化内核运行阶段
worker执行具体计算任务
调度器负责任务调度和事件管理
通过事件机制协调任务间的依赖关系
支持多GPU环境下的分布式执行。
3.3 关键步骤
3.3.1 计算图构建过程
此处对应模型翻译过程,即将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。
模型转换为计算图的工作是在build_mirage_graph函数中,其主要步骤如下:
初始化持久化内核
首先构建PersistentKernel实例。
mpk = mi.PersistentKernel(
world_size=world_size,
mpi_rank=rank,
num_workers=96,
num_local_schedulers=48,
num_remote_schedulers=0,
max_seq_length=4096,
eos_token_id=model.config.eos_token_id,
meta_tensors=[step_tensor, tokens_tensor],
profiler_tensor=profiler_tensor,
)
定义张量
将模型权重和中间张量添加到计算图中。
# 输入张量
x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
# 位置编码
positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
position_embeddings = model.model.rotary_emb(positions)
x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
cos_pos_embed = mpk.attach_input(
torch_tensor=position_embeddings[0][0, :MAX_CONTEXT_LEN, :],
name="cos_position_embedding",
)
sin_pos_embed = mpk.attach_input(
torch_tensor=position_embeddings[1][0, :MAX_CONTEXT_LEN, :],
name="sin_position_embedding",
)
# 计算图的中间结果张量
embed_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="embed_out")
attn_in = mpk.new_tensor(dims=(batch_size, fused_outdim_1 // world_size), dtype=mi.bfloat16, name="attn_in")
attn_out = mpk.new_tensor(dims=(batch_size, num_local_q_heads * head_dim), dtype=mi.bfloat16, name="attn_out")
is_nvshmem = "nvshmem_tensor" if world_size > 1 else "cuda_tensor"
attn_proj_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_proj_out", io_category=is_nvshmem)
allreduce_buf = mpk.new_tensor(dims=(world_size, batch_size, hidden_size), dtype=mi.bfloat16, name="all_reduce_buf", io_category=is_nvshmem)
attn_allreduce_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_allreduce_out", io_category=is_nvshmem)
mlp_mid = mpk.new_tensor(dims=(batch_size, fused_outdim_2 // world_size), dtype=mi.bfloat16, name="mlp_mid")
mlp_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_out", io_category=is_nvshmem)
mlp_final = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_final", io_category=is_nvshmem)
argmax_in = mpk.new_tensor(dims=(batch_size, vocab_size), dtype=mi.bfloat16, name="argmax_in")
argmax_part_value = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.bfloat16, name="argmax_part_value")
argmax_part_index = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.int64, name="argmax_part_index")
argmax_out = mpk.new_tensor(dims=(batch_size, 1), dtype=mi.int64, name="argmax_out")
构建计算层
通过调用各种layer方法将模型层添加到计算图。此处会把HuggingFace模型权重映射到Mirage张量。也可以融合张量以提高计算效率。
# --- Define the Model Graph ---
w_embed = mpk.attach_input(torch_tensor=model.model.embed_tokens.weight, name="embed_tokens")
mpk.embed_layer(input=x, weight=w_embed, output=embed_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
x = embed_out
for i, layer in enumerate(model.model.layers):
# Attention block
w_norm_attn = mpk.attach_input(torch_tensor=layer.input_layernorm.weight, name=f"layer_{i}_input_layernorm")
w_q = mpk.attach_input(torch_tensor=layer.self_attn.q_proj.weight, name=f"layer_{i}_q_proj")
w_k = mpk.attach_input(torch_tensor=layer.self_attn.k_proj.weight, name=f"layer_{i}_k_proj")
w_v = mpk.attach_input(torch_tensor=layer.self_attn.v_proj.weight, name=f"layer_{i}_v_proj")
w_qkv = mpk.fuse_tensors(inputs=[w_q, w_k, w_v], fused_dim=0, num_groups=num_local_kv_heads, name=f"layer_{i}_qkv_proj")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_attn, weight_linear=w_qkv, output=attn_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
w_q_norm = mpk.attach_input(torch_tensor=layer.self_attn.q_norm.weight, name=f"layer_{i}_q_norm")
w_k_norm = mpk.attach_input(torch_tensor=layer.self_attn.k_norm.weight, name=f"layer_{i}_k_norm")
k_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[0][i], name=f"layer_{i}_k_cache")
v_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[1][i], name=f"layer_{i}_v_cache")
mpk.attention_layer(input=attn_in, q_norm=w_q_norm, k_norm=w_k_norm, k_cache=k_cache, v_cache=v_cache, cos_pos_embed=cos_pos_embed, sin_pos_embed=sin_pos_embed, output=attn_out, grid_dim=(batch_size, num_local_kv_heads, 1), block_dim=(128, 1, 1))
w_o_proj = mpk.attach_input(torch_tensor=layer.self_attn.o_proj.weight, name=f"layer_{i}_o_proj")
mpk.linear_with_residual_layer(input=attn_out, weight=w_o_proj, residual=x, output=attn_proj_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = attn_proj_out
if world_size > 1:
mpk.allreduce_layer(input=attn_proj_out, buffer=allreduce_buf, output=attn_allreduce_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = attn_allreduce_out
# MLP block
residual_mlp = x
w_norm_mlp = mpk.attach_input(torch_tensor=layer.post_attention_layernorm.weight, name=f"layer_{i}_post_attn_layernorm")
w_gate_proj = mpk.attach_input(torch_tensor=layer.mlp.gate_proj.weight, name=f"layer_{i}_gate_proj")
w_up_proj = mpk.attach_input(torch_tensor=layer.mlp.up_proj.weight, name=f"layer_{i}_up_proj")
w_gatedup = mpk.fuse_tensors(inputs=[w_gate_proj, w_up_proj], fused_dim=0, num_groups=1, name=f"layer_{i}_gatedup_proj")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_mlp, weight_linear=w_gatedup, output=mlp_mid, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
w_down_proj = mpk.attach_input(torch_tensor=layer.mlp.down_proj.weight, name=f"layer_{i}_down_proj")
mpk.silu_mul_linear_with_residual_layer(input=mlp_mid, weight=w_down_proj, residual=residual_mlp, output=mlp_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = mlp_out
if world_size > 1:
mpk.allreduce_layer(input=mlp_out, buffer=allreduce_buf, output=mlp_final, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = mlp_final
# Final layer
w_final_norm = mpk.attach_input(torch_tensor=model.model.norm.weight, name="model_norm_weight")
w_lm_head = mpk.attach_input(torch_tensor=lm_head_weight, name="lm_head")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_final_norm, weight_linear=w_lm_head, output=argmax_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
# Argmax
mpk.argmax_partial_layer(input=argmax_in, output=(argmax_part_value, argmax_part_index), grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
mpk.argmax_reduce_layer(input=(argmax_part_value, argmax_part_index), output=argmax_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
3.3.2 任务图生成
此处对应任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。
调用compile()方法生成最终的执行图。compile()函数内会执行:
生成任务图。
创建CUDA代码。
调用nvcc编译器。
创建Python绑定模块。
mpk.compile()
print("Mirage graph compiled.")
return mpk