news 2026/4/17 19:36:30

MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础原理

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础原理

x00 概要

CMU 贾志豪老师团队提出的MPK(Mirage Persistent Kernel)是依托 Mirage 编译器生态的创新运行时系统,其核心能力在于将多GPU环境下大语言模型(LLM)推理任务自动转换为适配GPU架构的高性能巨型内核(megakernel)。MPK的关键优势在于将传统由CPU负责的内核调度和任务依赖管理工作转移到GPU端,通过“长期驻留的巨型内核(Persistent Kernel)”自主完成,同时统筹GPU内部计算与跨GPU通信任务。这种设计不仅大幅削减了CPU-GPU交互带来的内核启动开销,还通过计算与通信的细粒度重叠,将推理延迟优化至接近硬件物理极限,显著提升推理效率。

0.1 传统LLM推理框架的瓶颈

传统LLM推理框架的流程存在固有瓶颈:CPU需逐个发起CUDA内核调用(如矩阵乘法、激活函数计算),待GPU执行完当前内核并反馈后,再触发下一个内核。这种“CPU发起-GPU执行-CPU等待”的循环,会产生频繁的CPU-GPU通信与内核启动开销,尤其在自回归生成场景中,单次token生成需多轮内核调用,开销会持续累积,严重拖累整体推理性能。

0.2 MPK的流程重构

MPK彻底重构了这一流程:它仅需CPU在推理初始化阶段,向GPU提交一个“永不主动退出的persistent_kernel”,之后所有任务分派(如层间计算顺序)、依赖管理(如等待前一层结果再执行下一层)均由GPU内部自主完成。此时CPU的角色从“实时调度的包工头”转变为“启动初始化的门卫”,仅负责触发首次内核启动,后续不再参与任何具体调度。

0.3 MPK的关键优势

MPK通过将多GPU的LLM推理任务转换为高性能的巨型内核,从根本上改变了GPU的运行模式。它不仅减少了内核启动开销,还通过细粒度的软件流水线和计算通信重叠,显著提高了推理效率。MPK提供了一种全新的思路,将性能优化的重心从“如何调用优化库”转移到了“如何为整个模型生成一个最优的、原生的执行体”,在多GPU环境下实现了更高的吞吐量和更低的延迟。

0x01 问题

重新设计类似 Mirage 的 MegaKernel的优势,是将所有计算和通信融合进一个单一的巨型内核(也称为持续内核)是降低大语言模型推理延迟的最有效方法之一。这种方法通过启动一个GPU内核来执行整个模型,从逐层计算到GPU间通信,整个过程无需中断。尽管有这些优势,将LLM编译成巨型内核仍然极具挑战性。

1.1 现有框架问题

现有框架难以支持单一的巨型内核。

现有的高级ML框架,如PyTorch、Triton和TVM,并不原生支持端到端巨型内核生成。

现代LLM系统由各种不同的专用内核库构建而成,这种碎片化使得将整个推理流水线整合进一个单一的、统一的内核变得非常困难。

高性能GPU内核的手工编写需要大量的专家知识,如何自动生成高性能内核代码是一个痛点问题。传统做法依赖于专家编写好的内核或者手工融合规则,但这些方法维护成本高,容易漏掉跨内核/层级组合优化的机会。

1.2 编程抽象层级

从编程抽象层级上来看,也缺乏最优系统。

1.2.1 GPU架构

下图展示了当今 GPU 的层次结构。GPU 上的计算被组织为内核,每个内核都是一个函数,以单程序多数据(SPMD)的方式在多个 GPU 核心上同时执行。一个内核包括一个线程块网格,每个线程块在一个 GPU 流式多处理器上执行,并包括多个线程来对单个数据元素进行计算。每个线程都与一个每线程寄存器文件相关联,并且线程块内的所有线程都可以访问共享内存以启用集体操作。最后,内核的所有输入和输出都存储在 GPU 设备内存中。

下图是GPU hierarchy。

gpu_hierarchy

下图为GPU 计算架构和编程抽象示意图

image

1.2.2 编程视角

Triton 是一款高级 GPU 编程框架,其编程视角主要聚焦于块(Block)级别。该框架的设计允许开发者以块为单位进行编程,而块内部的优化工作则由 Triton 编译器自动完成。这种设计模式使开发者能够将精力集中在高层逻辑的构建上,无需深入研究线程(Thread)级别的细节实现。Triton 的核心优势在于其简洁的编程模型和自动化优化能力,这使得它在处理复杂并行任务时具有更高的效率。

Cutlass 则属于底层 GPU 编程库,其编程视角覆盖了块(Block)、线程束(Warp)与线程(Thread)的完整层级。Cutlass 提供了丰富的 CUDA 模板和底层控制接口,开发者可以利用这些工具精细调控每个线程的行为,从而实现高度优化的计算内核。这种细粒度的控制能力让 Cutlass 在对性能有极致要求的场景中表现出色,但同时也增加了编程的复杂性。

正是这种编程视角的层级差异,构成了当前高性能 GPU 编程领域的核心挑战:缺乏一套能够 “跨内核(Kernel)、线程块(Block)、线程(Thread)三个层级” 联合搜索最优计算方案,并自动验证方案正确性的系统。现有框架要么局限于单一层级的优化(例如 Triton 仅针对块内部逻辑进行优化,而 Cutlass 则需要开发者手动协调全层级的适配),要么无法在多层级协同后确保计算结果的准确性。这一问题在大型语言模型(LLM)推理等复杂张量计算场景中,会显著增加开发成本与优化难度。

0x02 总体思路

MegaKernel 可被视为一种 grid 级(网格级)的内核抽象。与 CUDA 的 thread 级(线程级)抽象、Triton 的 block 级(块级)抽象不同,它提供了层次更高的抽象能力,允许开发者在 grid 级开展编程工作。这种抽象设计能让开发者更灵活地管理 GPU 上的计算资源,进而实现更高效的内核生成与执行。

MPK 的工作原理主要包含以下两部分:

MPK 编译器:负责将大语言模型(LLM)的计算图转换为经过优化的任务图。

MPK 运行时系统:在单个巨型内核内部执行任务图,以此达成高吞吐量与低延迟的目标。

2.1 编译过程

模型翻译:将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。

任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。

2.2 执行过程

任务调度:将生成的任务图交付调度器执行。调度器负责管理 GPU 中的流式多处理器(SM),并通过 warp specialization(线程束特化)技术将 SM 划分为 worker(工作单元)与 scheduler(调度单元)。这种设计与数据处理领域的 actor 模型(角色模型)相似:scheduler 负责协调任务的执行顺序,worker 则负责具体执行分配到的任务。

性能优化:在小模型与低 Batch(批次)场景下,MPK 通过多种方式显著降低延迟,具体包括:消除内核启动开销、打破内核边界限制、实现细粒度的 SM 调度,以及对任务特定模式进行融合。

0x03 通过代码来打通流程

我们以demo_chat.py为例来进行全局打通。在此文件中会将Python模型结构映射为Mirage的计算图表示,然后编译为高效的持久化CUDA内核执行。

3.1 核心模块说明

3.1.1 三层结构化图模型

Mirage 实现了多层次计算图表示(μGraphs),通过 kernel-graph、block-graph 和 thread-graph 这三层结构化图模型,精确映射了 GPU 程序从内核到线程的执行逻辑与存储层级。这种三层结构与 CUDA 程序的执行层级及 GPU 的存储体系紧密对应,每层都清晰定义了“算子类型 - 张量存储 - 核心功能”的关联。

三层图功能如下:

Kernel Graph 是最高计算图,定义整个执行流程。通过自定义操作管理多个block graph

Block Graph 是嵌套在自定义操作中,定义线程块执行序列

Thread Graph是最低层,定义线程级别执行细节

3.1.2 PersistentKernel

PersistentKernel 作为计算图的容器和执行器,提供了从计算图构建、优化到执行的过程。

persistent_kernel.py是 PersistentKernel的Python接口,本质是Python到CUDA持久化内核系统的桥梁,允许用户用python定义复杂的计算图,然后在GPU上高效执行。

3.1.3 层级关系

计算图与 PersistentKernel 的关系如下:

包含关系:PersistentKernel 内部包含并管理一个 Kernel graph

构建关系:通过 PersistentKernel 的各种layer方法构建计算图。

转换关系:PersistentKernel 将计算图转换为可执行的任务图

执行关系:PersistentKernel 是计算图的执行引擎。

3.1.4 数据流关系

数据流关系可以近似如下图所示:

应用层:PersistentKernel.py(创建并管理kernel graph)

输入张量

计算图节点(各种layer方法添加)

任务层:kernel graph(包括所有操作和计算流,即定义张量数据流)

并行层:block graph(嵌套在自定义操作中,定义线程块执行序列,即定义内存访问模式)

执行层:task graph(kernel graph生成的可执行任务图,taskDesc是可执行任务,EventDesc管理事件同步和依赖)

运行时环境:PersistentKernel 执行引擎

硬件层:Thread graph,在实际GPU线程中执行具体操作

3.2 main()代码

demo_chat.py的main()如下。

def main():

world_size, rank = setup_distributed_environment()

model, tokenizer = load_model_and_tokenizer(rank)

tokens = torch.full((1, MAX_SEQ_LEN), 0, dtype=torch.long, device="cuda")

step_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")

mpk = None

if args.use_mirage:

# 构建计算图

mpk = build_mirage_graph(model, world_size, rank, args, tokens, step_tensor)

positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)

position_embeddings = model.model.rotary_emb(positions)

messages = [{"role": "system", "content": SYSTEM_PROMPT}]

while True:

prompt_container = [None]

if rank == 0:

try:

prompt = input("> User: ")

prompt_container[0] = prompt

except EOFError:

prompt_container[0] = "exit"

if world_size > 1:

dist.broadcast_object_list(prompt_container, src=0)

prompt = prompt_container[0]

messages.append({"role": "user", "content": prompt})

text = tokenizer.apply_chat_template(

messages, tokenize=False, add_generation_prompt=True

)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

new_prompt_len = model_inputs.input_ids.shape[-1]

tokens[0, :new_prompt_len] = model_inputs.input_ids[0]

if new_prompt_len < tokens.shape[1]:

tokens[0, new_prompt_len:] = 0

prompt_len = new_prompt_len

if args.use_mirage:

end_pos, run_time, generated_len = run_mirage_generation(

model, mpk, tokens, prompt_len, step_tensor, position_embeddings

)

else:

end_pos, run_time, generated_len = run_pytorch_generation(

model, tokens, prompt_len, step_tensor, position_embeddings

)

if rank == 0:

assistant_response_ids = tokens[0, prompt_len:end_pos]

assistant_response = tokenizer.decode(assistant_response_ids, skip_special_tokens=True)

if world_size > 1:

dist.destroy_process_group()

print("Exiting demo.")

总体过程如下:

模型定义阶段

使用PyTorch/HuggingFace定义模型结构。

加载预训练权重

初始化输入张量和相关参数。

任务图构建阶段

通过KNOperator定义计算操作

构建完整的计算图结构

设置任务配置参数

任务图优化阶段

分析任务间的依赖关系

生成事件描述以管理依赖

对任务进行合理分组以优化执行

任务图转换阶段

生成TaskDesc描述每个计算任务

生成EventDesc描述任务间同步事件

生成CUDA可执行代码

输出JSON配置文件用于运行时加载。

运行时初始化阶段

配置GPU资源(worker,调度器等)

分配GPU内存给任务队列和事件队列

初始化工作队列和调度队列。

设置事件计数器和相关同步机制

持久化内核运行阶段

worker执行具体计算任务

调度器负责任务调度和事件管理

通过事件机制协调任务间的依赖关系

支持多GPU环境下的分布式执行。

3.3 关键步骤

3.3.1 计算图构建过程

此处对应模型翻译过程,即将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。

模型转换为计算图的工作是在build_mirage_graph函数中,其主要步骤如下:

初始化持久化内核

首先构建PersistentKernel实例。

mpk = mi.PersistentKernel(

world_size=world_size,

mpi_rank=rank,

num_workers=96,

num_local_schedulers=48,

num_remote_schedulers=0,

max_seq_length=4096,

eos_token_id=model.config.eos_token_id,

meta_tensors=[step_tensor, tokens_tensor],

profiler_tensor=profiler_tensor,

)

定义张量

将模型权重和中间张量添加到计算图中。

# 输入张量

x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")

# 位置编码

positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)

position_embeddings = model.model.rotary_emb(positions)

x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")

cos_pos_embed = mpk.attach_input(

torch_tensor=position_embeddings[0][0, :MAX_CONTEXT_LEN, :],

name="cos_position_embedding",

)

sin_pos_embed = mpk.attach_input(

torch_tensor=position_embeddings[1][0, :MAX_CONTEXT_LEN, :],

name="sin_position_embedding",

)

# 计算图的中间结果张量

embed_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="embed_out")

attn_in = mpk.new_tensor(dims=(batch_size, fused_outdim_1 // world_size), dtype=mi.bfloat16, name="attn_in")

attn_out = mpk.new_tensor(dims=(batch_size, num_local_q_heads * head_dim), dtype=mi.bfloat16, name="attn_out")

is_nvshmem = "nvshmem_tensor" if world_size > 1 else "cuda_tensor"

attn_proj_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_proj_out", io_category=is_nvshmem)

allreduce_buf = mpk.new_tensor(dims=(world_size, batch_size, hidden_size), dtype=mi.bfloat16, name="all_reduce_buf", io_category=is_nvshmem)

attn_allreduce_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_allreduce_out", io_category=is_nvshmem)

mlp_mid = mpk.new_tensor(dims=(batch_size, fused_outdim_2 // world_size), dtype=mi.bfloat16, name="mlp_mid")

mlp_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_out", io_category=is_nvshmem)

mlp_final = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_final", io_category=is_nvshmem)

argmax_in = mpk.new_tensor(dims=(batch_size, vocab_size), dtype=mi.bfloat16, name="argmax_in")

argmax_part_value = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.bfloat16, name="argmax_part_value")

argmax_part_index = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.int64, name="argmax_part_index")

argmax_out = mpk.new_tensor(dims=(batch_size, 1), dtype=mi.int64, name="argmax_out")

构建计算层

通过调用各种layer方法将模型层添加到计算图。此处会把HuggingFace模型权重映射到Mirage张量。也可以融合张量以提高计算效率。

# --- Define the Model Graph ---

w_embed = mpk.attach_input(torch_tensor=model.model.embed_tokens.weight, name="embed_tokens")

mpk.embed_layer(input=x, weight=w_embed, output=embed_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))

x = embed_out

for i, layer in enumerate(model.model.layers):

# Attention block

w_norm_attn = mpk.attach_input(torch_tensor=layer.input_layernorm.weight, name=f"layer_{i}_input_layernorm")

w_q = mpk.attach_input(torch_tensor=layer.self_attn.q_proj.weight, name=f"layer_{i}_q_proj")

w_k = mpk.attach_input(torch_tensor=layer.self_attn.k_proj.weight, name=f"layer_{i}_k_proj")

w_v = mpk.attach_input(torch_tensor=layer.self_attn.v_proj.weight, name=f"layer_{i}_v_proj")

w_qkv = mpk.fuse_tensors(inputs=[w_q, w_k, w_v], fused_dim=0, num_groups=num_local_kv_heads, name=f"layer_{i}_qkv_proj")

mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_attn, weight_linear=w_qkv, output=attn_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))

w_q_norm = mpk.attach_input(torch_tensor=layer.self_attn.q_norm.weight, name=f"layer_{i}_q_norm")

w_k_norm = mpk.attach_input(torch_tensor=layer.self_attn.k_norm.weight, name=f"layer_{i}_k_norm")

k_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[0][i], name=f"layer_{i}_k_cache")

v_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[1][i], name=f"layer_{i}_v_cache")

mpk.attention_layer(input=attn_in, q_norm=w_q_norm, k_norm=w_k_norm, k_cache=k_cache, v_cache=v_cache, cos_pos_embed=cos_pos_embed, sin_pos_embed=sin_pos_embed, output=attn_out, grid_dim=(batch_size, num_local_kv_heads, 1), block_dim=(128, 1, 1))

w_o_proj = mpk.attach_input(torch_tensor=layer.self_attn.o_proj.weight, name=f"layer_{i}_o_proj")

mpk.linear_with_residual_layer(input=attn_out, weight=w_o_proj, residual=x, output=attn_proj_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))

x = attn_proj_out

if world_size > 1:

mpk.allreduce_layer(input=attn_proj_out, buffer=allreduce_buf, output=attn_allreduce_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))

x = attn_allreduce_out

# MLP block

residual_mlp = x

w_norm_mlp = mpk.attach_input(torch_tensor=layer.post_attention_layernorm.weight, name=f"layer_{i}_post_attn_layernorm")

w_gate_proj = mpk.attach_input(torch_tensor=layer.mlp.gate_proj.weight, name=f"layer_{i}_gate_proj")

w_up_proj = mpk.attach_input(torch_tensor=layer.mlp.up_proj.weight, name=f"layer_{i}_up_proj")

w_gatedup = mpk.fuse_tensors(inputs=[w_gate_proj, w_up_proj], fused_dim=0, num_groups=1, name=f"layer_{i}_gatedup_proj")

mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_mlp, weight_linear=w_gatedup, output=mlp_mid, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))

w_down_proj = mpk.attach_input(torch_tensor=layer.mlp.down_proj.weight, name=f"layer_{i}_down_proj")

mpk.silu_mul_linear_with_residual_layer(input=mlp_mid, weight=w_down_proj, residual=residual_mlp, output=mlp_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))

x = mlp_out

if world_size > 1:

mpk.allreduce_layer(input=mlp_out, buffer=allreduce_buf, output=mlp_final, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))

x = mlp_final

# Final layer

w_final_norm = mpk.attach_input(torch_tensor=model.model.norm.weight, name="model_norm_weight")

w_lm_head = mpk.attach_input(torch_tensor=lm_head_weight, name="lm_head")

mpk.rmsnorm_linear_layer(input=x, weight_norm=w_final_norm, weight_linear=w_lm_head, output=argmax_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))

# Argmax

mpk.argmax_partial_layer(input=argmax_in, output=(argmax_part_value, argmax_part_index), grid_dim=(96, 1, 1), block_dim=(128, 1, 1))

mpk.argmax_reduce_layer(input=(argmax_part_value, argmax_part_index), output=argmax_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))

3.3.2 任务图生成

此处对应任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。

调用compile()方法生成最终的执行图。compile()函数内会执行:

生成任务图。

创建CUDA代码。

调用nvcc编译器。

创建Python绑定模块。

mpk.compile()

print("Mirage graph compiled.")

return mpk

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 5:19:48

PHP 异常处理全攻略 Try-Catch 从入门到精通完全指南

什么是 Try-Catch&#xff1f; Try-catch 是 PHP 处理异常的机制——程序执行期间发生的意外事件或错误。与其让应用程序崩溃&#xff0c;try-catch 允许你拦截这些错误并优雅地处理它们。 把它想象成一张安全网。你“尝试”执行可能失败的代码&#xff0c;如果失败了&#xf…

作者头像 李华
网站建设 2026/4/18 8:16:26

达梦数据库安装

好的&#xff0c;这是一篇关于达梦数据库&#xff08;DM Database&#xff09;安装的详细指南&#xff0c;包含目录、文字说明和图片位置示意&#xff0c;内容丰富&#xff0c;力求达到3000字的要求。达梦数据库安装与配置详细指南目录引言1.1 达梦数据库简介1.2 安装前准备的重…

作者头像 李华
网站建设 2026/4/5 16:03:40

非线性最优化问题求解器Ipopt介绍

文章目录一、关键输入信息1、优化问题的维度2、优化变量的边界3、优化问题的初始迭代点&#xff1a;4、优化问题的数据结构(Structure)&#xff1a;5、优化问题函数的值&#xff1a;二、C Interface1、Ipopt::TNLP::get_nlp_info2、Ipopt::TNLP::get_bounds_info3、Ipopt::TNLP…

作者头像 李华
网站建设 2026/4/18 6:32:28

springboot人事系统(11545)

有需要的同学&#xff0c;源代码和配套文档领取&#xff0c;加文章最下方的名片哦 一、项目演示 项目演示视频 二、资料介绍 完整源代码&#xff08;前后端源代码SQL脚本&#xff09;配套文档&#xff08;LWPPT开题报告&#xff09;远程调试控屏包运行 三、技术介绍 Java…

作者头像 李华
网站建设 2026/4/18 7:03:03

智慧城市与智慧校园之安防暴力检测 校园打架斗殴检测 街边暴力躁动识别 危险物品识别 智能安防 安防领域智能化 数据集第10319期 (1)

暴力检测数据集 本文档为深度学习相关研究与应用开发&#xff0c;提供暴力检测数据集的核心信息说明。数据集核心信息表项目详情类别共 4 类&#xff0c;分别为非暴力&#xff08;NonViolence&#xff09;、暴力&#xff08;Violence&#xff09;、枪支&#xff08;guns&#x…

作者头像 李华