https://github.com/THUDM/slime/blob/c525704f/docs/en/get_started/usage.md
使用指南
slime 参数介绍
在使用 slime 时,传递参数主要用于以下目的:
- 分配集群中的一部分 GPU 用于训练,另一部分用于推理。
- 为训练部分加载 Megatron。
- 为推理部分加载 SGLang。
- 配置 RL 训练所需的超参数。
遵循这个顺序,我们需要配置以下这些参数:
集群资源分配
集群资源分配主要有四个参数:
--actor-num-nodes: RL actor 训练所需的节点数。--actor-num-gpus-per-node: RL actor 训练每个节点的 GPU 数量。--rollout-num-gpus: rollout (推理) 所需的总 GPU 数量。--rollout-num-gpus-per-engine: 每个推理引擎的 GPU 数量。这个参数类似于 SGLang 的tp_size。当进行多节点服务时,这个值应该是总 GPU 数。例如,如果使用 2 个节点和 16 个 GPU 服务一个模型,这个值就应该是 16。
之所以不使用像--sglang-tp-size这样的参数,是因为我们未来可能会考虑支持 SGLang 的dp_size参数,这意味着一个 engine 可能包含多个 SGLang server(目前仅支持--sglang-enable-dp-attention条件下的--sglang-dp-size)。
在默认配置下,我们使用这些参数通过 Ray 分配actor_num_nodes * actor_num_gpus_per_node个 GPU 用于训练,以及rollout_num_gpus个 GPU 用于推理,从而实现训练和推理资源的隔离。
对于训练和推理的混合部署 (co-located),你还需要配置:
--colocate: 开启训练和推理的混合部署。开启后,会忽略--rollout-num-gpus,并使训练和推理的 GPU 数量相等。
加载 Megatron
和 SGLang、vLLM、Hugging Face Trainer 等工具不同,Megatron 无法直接读取 Hugging Face 的 checkpoint。取而代之的是,用户必须为待训练的模型配置参数,并加载 Megatron 自己的 checkpoint 格式。
通常,我们需要进行三个准备步骤:
- 配置模型参数。
- 配置并行策略和其他优化。
- 配置待加载的 checkpoint。
关于 Megatron 的一些定制化以及 slime 是如何集成 Megatron 的原理,请参考「如何使用 Megatron」章节。
配置模型参数
以 qwen3 4B 为例,我们需要这些参数:
MODEL_ARGS=(--num-layers36--hidden-size2560--ffn-hidden-size9728--swiglu --vocab-size151936--disable-bias-linear# attn head--num-attention-heads32--group-query-attention --num-query-groups8--kv-channels128--qk-layernorm# norm--normalization"RMSNorm"--norm-epsilon 1e-6# rope--use-rotary-position-embeddings --rotary-base1000000)我们为常见的模型在 scripts/models 中提供了配置,你可以直接复用。如果你也使用 Megatron 进行预训练/SFT,你可以直接复用你预训练/SFT 设置中的模型配置。
注意:
- slime 会加载
PYTHONPATH下的 Megatron 的所有参数,因此你可以在你环境中的 Megatron 内部找到参数及其描述。 - slime 使用数据打包(也称 varlen 或 thd)进行训练,无需配置
--seq-length或--max-positional-embedding,因为这些参数不影响训练后模型的最大上下文长度。
设置并行策略和重计算
Megatron 是目前优化最全面的训练框架之一。使用 Megatron 的一个主要原因就是追求其卓越的性能。这里简要介绍如何配置 Megatron 的并行策略和重计算。
- 这里我们列出 Megatron 的并行策略。关于这些策略之间权衡的更详细讨论,请参考更专业的讨论:
--tensor-model-parallel-size: TP (张量并行)--sequence-parallel: Megatron 的 SP 是对 TP 的一种优化。建议在使用 TP 时始终开启 SP。--pipeline-model-parallel-size: PP (流水线并行)--context-parallel-size: Megatron 的 CP,也称为序列并行,通常对应于 Ring Attention。--expert-model-parallel-size: MoE 的 EP (专家并行),每个 GPU 拥有num_experts / ep_size个专家。--expert-tensor-parallel-size: Megatron 支持为 MoE 专家使用与模型其他部分不同的tp_size,我们通常称之为 ETP。
- 对于重计算,Megatron 中通常配置以下标志:
--recompute-granularity: 可以设置为full或selective。full表示完全重计算,而selective重计算的内容较少。如果不配置,则不进行重计算。--recompute-method: 通常uniform就足够了。--recompute-num-layers: 每组进行重计算的层数。通常设为 1 即可。
加载 Megatron Checkpoints
Megatron 支持多种其自定义的 checkpoint 格式。以下是两种比较常见的:
- 曾经主流的
torch格式 (对应--ckpt-format torch)。 - 当前推荐的
torch_dist格式 (对应--ckpt-format torch_dist)。
torch格式是 Megatron 较旧的存储格式。其结构由mp_rank_xxx这样的目录组成,每个目录对应于特定并行分区下每个 rank 存储的 checkpoint。因此,在加载torch格式的 checkpoint 时,你必须保证 checkpoint 的并行策略与训练任务的并行策略一致。
我们推荐使用torch_dist格式,因为它支持自动并行切分,这意味着不同并行设置的训练任务可以共享同一个 checkpoint,这方便得多。torch_dist也是开源 Megatron 中的默认格式。一个torch_dist格式的 checkpoint 通常包含一组.distcp文件。使用torch_dist时,你可以通过 README 中描述的 checkpoint 转换方法,实现从 Hugging Face 到torch_dist以及反之的转换。
在存储结构上,一个 Megatron checkpoint 通常看起来是这样的,假设存储路径是/ckpt/:
--/ckpt/|-- latest_checkpointed_iteration.txt|-- iter_0000100/|-- _0_0.distcp|-- _0_1.distcp|--...|-- iter_0000200/|-- iter_0000300/|--...latest_checkpointed_iteration.txt文件记录了最新的训练步数。加载模型时不应该直接传入/ckpt/iter_xxxxxxx,而应该传入/ckpt/,并通过--ckpt-step来选择对应的训练步数(如果不使用--ckpt-step,则会从latest_checkpointed_iteration.txt中读取步数)。
使用 slime 时,有三个用于加载和保存 checkpoint 的参数:
--ref-load: 参考模型 (reference model) 的 Megatron checkpoint。--load: actor 的 Megatron checkpoint。如果未设置--load,或者指定的目录不存在或不包含latest_checkpointed_iteration.txt,actor 将从--ref-load的 checkpoint 初始化。--save: 保存 actor checkpoint 的路径。
注意:
- 无论 checkpoint 的存储方式如何(即
--ckpt-format如何设置),Megatron 都可以加载torch和torch_dist两种格式。
加载 SGLang
加载 SGLang 非常简单。你只需要:
--hf-checkpoint: 用于初始化 SGLang 的 Hugging Face checkpoint。
注意:
- slime 会在第一次训练前将 Megatron 的参数同步给 SGLang,因此
--hf-checkpoint无需包含最新的训练参数,并且在恢复训练时也无需更改 HF checkpoint。 - 默认情况下,SGLang 从 Hugging Face checkpoint 中的
config.json读取最大上下文长度。你可以使用--sglang-context-length参数覆盖此值以支持更长的推理。 - 在混合部署训练和推理时,尽管 Megatron 和 SGLang 会依次卸载,但它们仍需要为彼此留出一些内存。你需要通过减小
--sglang-mem-fraction-static来调整 SGLang 的总显存使用率。
关于 SGLang 的一些定制化以及 slime 是如何集成 SGLang 的原理,请参考「如何使用 SGLang」章节。
数据格式
目前,slime 只支持加载.jsonl格式的文件,其中文件的每一行都是一个 JSON 对象。单个数据条目的示例(展开后)如下:
{"prompt":[{"content":"Solve the following math problem step by step. The last line of your response should be of the form Answer: \\boxed{$Answer} where $Answer is the answer to the problem.\n\nIn triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.\n\nRemember to put your answer on its own line after \"Answer:\".","role":"user","step_loss_mask":1,}],"label":"34"}这对应于以下配置:
--input-key prompt --label-key label --apply-chat-template请注意,这里的step_loss_mask(默认为1)用于SFT阶段。如果设置为0,该轮对话不会对最终的loss产生贡献;如果设置为1,slime 将使用常规的loss_mask。
此外,我们提供了一个metadata_key,默认为"metadata"。读取时,slime 会从数据中加载元数据,这对于自定义数据生成或创建自定义奖励模型很有帮助。
RL 训练的超参数
--advantage-estimator: 指定训练过程中的 RL 算法。目前支持的算法包括:grpo(https://arxiv.org/abs/2402.03300)gspo(https://arxiv.org/abs/2507.18071)reinforce_plus_plus和reinforce_plus_plus_baseline(https://arxiv.org/abs/2501.03262)ppo(https://arxiv.org/abs/1707.06347)
--calculate-per-token-loss: 默认情况下,Slime 以样本为单位计算损失,即mean(sum(sample_i) / len(sample_i))。启用此标志后,将以 token 为单位计算损失,即sum(sum(sample_i)) / sum(len(sample_i))。--use-tis: 启用此设置以使用 TIS (截断重要性采样 Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl)。
自定义 Rollout 函数
slime 支持不同程度地自定义数据生成 (rollout)。
默认情况下,它使用 slime/rollout/sglang_example.py 中的
generate_rollout函数进行数据生成。该文件实现了一个基于 SGLang 的异步 (asyncio) 数据生成流程,并支持动态采样和部分 rollout 等功能。你可以使用
--rollout-function-path参数完全替换 sglang_example.py 中的generate_rollout。你只需确保通过--rollout-function-path传递的函数签名如下:defgenerate_rollout(args,rollout_id,data_buffer,evaluation=False)->list[list[Sample]]:""" Args: args: 完整的参数 rollout_id: int, rollout 的 id,用于确定性地生成数据 data_buffer: 用于存储生成样本的数据缓冲区 evaluation: bool, 本次 rollout 是否用于评估 Returns: list[list[Sample]]: 本次 rollout 生成的样本列表 """...returnsamples其中:
args: 运行 slime 时使用的完整参数。rollout_id: 当前数据生成轮次的 ID,用于在恢复训练时确保数据顺序。data_buffer: slime 中全局唯一的数据缓冲区,可用于获取初始提示、数据 ID,并存储部分生成的样本以供后续使用。evaluation: 一个布尔值,指示该 rollout 是否用于评估。你可以使用--eval-function-path配置一个单独的评估函数。返回的
Sample类型定义在 slime/utils/types.py 中。实现时,你需要确保以下字段被正确设置:tokens: prompt + response 的 tokens。response_length: response 的总长度。对于多轮任务,这是第一轮 prompt 之后剩余 token 的长度。reward: 此数据样本的奖励。truncated: 此数据样本是否被截断,类似于 SGLang 中的finish_reason == length。
如果存在工具调用或多轮使用等场景,请确保
loss_mask正确:loss_mask的长度应与response_length相同,需要计入 loss 计算的 token 对应位置为1,应被屏蔽的为0。
在某些情况下,你可能只需要替换数据生成逻辑。你可以使用
--custom-generate-function-path来实现。该函数的一个简化实现如下:asyncdefgenerate(args,sample:Sample,sampling_params)->Sample:globalTOKENIZERifTOKENIZERisNone:TOKENIZER=AutoTokenizer.from_pretrained(args.hf_checkpoint,trust_remote_code=True)# 向 router 发送请求output=awaitpost(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate",{"text":sample.prompt,"sampling_params":sampling_params,})prompt_tokens_ids=TOKENIZER(sample.prompt,add_special_tokens=False)["input_ids"]response_token_ids=TOKENIZER(output["text"],add_special_tokens=False)["input_ids"]# 设置 samplesample.tokens=prompt_tokens_ids+response_token_ids sample.response_length=len(response_token_ids)sample.truncated=output["meta_info"]["finish_reason"]["type"]=="length"sample.response=output["text"]sample.aborted=output["meta_info"]["finish_reason"]["type"]=="abort"returnsample更完整的版本请参考 slime/rollout/sglang_example.py。
有时,你可能还需要支持自定义的奖励模型。这可以通过设置
--custom-rm-path来配置。
如何使用 SGLang
slime 通过HttpServerEngineAdapter这个中间层,使用 SGLang 实现了一个基于 Server 的引擎。
参数配置
slime 通过使用 SGLang 的ServerArgs.add_cli_args集成了几乎所有的 SGLang 参数。当设置 SGLang 的参数时,你需要加上--sglang-前缀。例如:
- 在混合部署训练和推理时,你通常需要限制
--mem-fraction-static。这个参数应改为--sglang-mem-fraction-static。 - 在训练期间,如果你希望 SGLang 推理的长度超过 Hugging Face checkpoint 的
config.json中指定的最大上下文长度,你需要使用--context-length,在 slime 中它变成--sglang-context-length。 - 对于多节点大型 EP 推理,你可能需要
--enable-ep-moe、--enable-dp-attention、--dp-size、--enable-deepep-moe等。这些可以分别作为--sglang-enable-ep-moe、--sglang-enable-dp-attention、--sglang-dp-size和--sglang-enable-deepep-moe传入。
一些与 slime 资源调度相关的参数由 slime 自己配置,例如:
- slime 中的
--tp-size是通过--rollout-num-gpus-per-engine设置的。 - slime 中的
--model-path是通过--hf-checkpoint设置的。
SGLang 参数集成到 slime 的方式可以在 slime/backends/sglang_utils/arguments.py 中找到。
如何使用 Router
slime 使用 sglang-router 来管理训练过程中的 SGLang server。你可以使用--sglang-router-ip和--sglang-router-port配置 sglang-router 的地址。如果未配置,默认会在集群内部启动一个 router。
启动后,所有的 SGLang server 将通过/add_worker端点向 router 注册。在实际生成数据时,你只需要向 router 发送 HTTP 请求,router 会进行负载均衡并将请求转发给 server。
当你使用--sglang-router-ip和--sglang-router-port配置外部 router 时,slime 不会启动内部 router,而是将其所有 server 注册到这个外部 router。然后你可以使用这个外部 router 的地址来实现更复杂的数据生成工作流。注意,该 router 支持 OpenAI 兼容的 API。
如何使用 Megatron
slime 通过复用megatron.training目录下的通用函数(如parse_args、save_checkpoint和load_checkpoint)来支持不同且轻度修改的 Megatron 版本。因此,在使用时,你必须确保 Megatron 在PYTHONPATH中是可访问的,例如,在运行时添加export PYTHONPATH=/root/Megatron-LM。
参数配置
slime 通过使用from megatron.training.arguments import parse_args直接导入当前环境中 Megatron 的所有参数。如果你使用的 Megatron 版本有在parse_args之外定义的参数,你可以像 train.py 中那样将它们传入进行配置,例如:
if__name__=="__main__":try:frompretrain_gptimportextra_args_providerexcept:extra_args_provider=Noneargs=parse_args(extra_args_provider)train(args)自定义参数
在一些定制的 Megatron 实现中,需要在初始化或训练步骤前后执行特殊操作。我们为此添加了以下插件:
--custom-megatron-init-path: 添加一些初始化调用。--custom-megatron-before-log-prob-hook-path: 在计算对数概率 (log probability) 之前调用。--custom-megatron-before-train-step-hook-path: 在每个训练步骤之前调用。例如,你可以用它来混合特殊的训练损失。