在 Part 1 中,我们介绍了 verl 的初始化过程,我们进一步介绍 verl 的训练过程,包括rollout部分、make experience部分以及training部分。
在 GRPO 中,单个 step 包含四个阶段:load data -> rollout -> make experience -> update model。区别于前一节的详述,本节会使用伪代码结合源码的方式进行阐述。
flowchart LR subgraph W2["Initialize"] WP[Process Data] --> A direction TB D1[Data Prepare] --> A A[TaskRunner] --> B1[RayPPOTrainer] B1 --> Workers subgraph Workers["Workers"] direction TB WA[ActorRolloutWorker] --> WD[FSDP Engine] WB[CriticWorker] --> WD WC[RewardModelWorker] --> WD WD --> WE[SGLang Engine] end Workers --> C1[Hybrid Engine] end subgraph W3["Train Loop"] direction TB E[DataLoader] --> RolloutBox subgraph RolloutBox["Rollout"] F1[Prepare Data] --> F2[SGLang Async Rollout] F2 --> F3[Multi-turn Chat Process] end RolloutBox --> ExpBox subgraph ExpBox["Make Experience"] G1[Recompute Log Probs] --> G2[Compute Reward] G2 --> G3[Compute Advantage] end ExpBox --> UpdateBox subgraph UpdateBox["Train The Model"] H1[Load FSDP Model Weight] --> H2[Compute Gradient] H2 --> H3[Weights Update] H3 --> H4[Sync Weights] end UpdateBox --> E end W2 --> W3
数据加载与预处理
verl 通过 DataProto
和 RLHFDataset
来实现数据处理。具体来说,在 main_ppo.py
中,我们观察这个函数:
create_rl_dataset 源码
|
|
非常典型,创造一个了 RLHFDataset
实例,并返回。而具体的 RLHFDataset
实现如下:
RLHFDataset 实现
|
|
- 支持从远程存储下载 Parquet 文件到本地缓存,支持共享内存加速文件访问,自动管理文件路径,支持检查点恢复。
- 使用 HuggingFace
datasets
库读取 Parquet 文件,支持多个数据文件的合并,自动处理数据格式转换。 - 根据最大长度过滤过长的 prompts,支持多进程并行处理,可配置的过滤策略。
- 支持图像和视频的多模态输入,解析
<image>
和<video>
标签,将多模态内容转换为结构化格式。 - 添加 chat template 来格式化对话,将文本转换为 token IDs,生成 attn mask 和 position ids。
- padding 到指定长度,支持多种截断策略(left, right, middle, error),生成位置编码。
- 支持训练中断后的恢复,可以从原始文件重新构建数据集,兼容序列化/反序列化。
- 返回包含以下关键字段的字典:
input_ids
,attention_mask
,position_ids
,raw_prompt_ids
,multi_modal_data
,multi_modal_inputs
,index
,tools_kwargs
。
这里最重要的一个参数是 tools_kwargs
,用于为不同的 tools 提供配置参数。它的结构如下:
|
|
比如 Search-R1 的 tools_kwargs
如下:
|
|
具体这些参数是如何调用了一个 tool,我们会留在后续部分继续介绍。
训练入口 RayPPOTrainer.fit()
- 创建 Tracking 日志记录器,设置全局步数,加载检查点,并在训练前进行验证。
- 使用 tqdm 创建进度条,显示训练进度,并设置初始步数。
- 遍历配置的总 epoch 数和数据加载器,每个 train batch 更新多步。
- 从 batch 中分离出用于 rollout 的数据(
input_ids
,attention_mask
,position_ids
等),保留其他数据用于后续处理。 - 调用
ActorRolloutWorker
生成序列,并记录生成时间。 - 处理 REMAX 基线(如果使用):生成确定性基线序列,计算基线奖励,用于 REMAX 优势估计器。
- 为每个样本分配唯一 ID,重复数据以对齐多次采样,计算响应掩码,并可选地进行批次平衡。
- 根据配置使用奖励模型或自定义奖励函数计算 token 级别的奖励分数,支持同步和异步计算。
- 使用 megatron 基于训练开始前的 policy 重新计算 behaviour policy 的 log probabilities,用于重要性采样,同时计算熵值。(原因在 part 1讲过)
- 使用 reference policy 计算 log probs,用于 KL 散度计算。
- 使用 Critic 网络计算状态价值,用于优势函数估计。
- 根据配置的优势估计器(GAE、GRPO、REMAX 等)计算优势函数,支持 KL 惩罚。
- 使用计算出的优势函数更新 Critic 网络参数。
- 在 Critic 预热完成后,使用 PPO 损失函数更新 Actor 网络参数。
- 将生成的序列、输入、输出和分数保存到指定目录。
- 根据配置的频率执行验证,计算验证指标并记录。
- 根据配置的频率保存模型检查点。
- 收集训练指标、时序指标和吞吐量指标,并记录到日志系统。
- 更新进度条,递增全局步数,并在达到总训练步数时结束训练。
- 根据配置在特定步数启用/禁用性能分析,用于调试和优化。
RayPPOTrainer.fit() 源码
|
|
我们究竟在异步什么?
这里很值得分享一个核心问题,对 SGLang 而言,或者对现在的 RL 而言,我们每天说来说去的 async 究竟是什么意思?和 PD 分离一样,async 也有非常多的层面:
Async RL 代表的是在 training rollout 分离的系统上,rollout 只在 update weights 的时候被打断,其余时刻永远 rollout,哪怕 target policy 正在被 training engine 更新。这方面是 AreaL 和 SLIME。
Async Rollout 这个词是特指在 rollout 的时候,把一个 batch requests 拆为单个 request,然后逐个调用
SGLangEngine.generate()
。
乍一听,这没有什么特别的,似乎还会更慢些。但是考虑到 tool call 的问题,这就非常严肃了。假设我们把一整个 batch 的 requests 作为一个 batch 塞给 sglang 似乎还要快些,毕竟对 SGLang 的 scheduler 而言,更好组 batch。但是,一整个 batch 进去,得一整个 batch 出来。这些 batch 里面的 requests 同时返回,同时被 paser 解析查看是否有 tool call 的 parameter,然后发送请求给 tool。如此以来,整个 tool 的调用大概率会拥堵,甚至在我们考虑到如果要加入多个 tool(虽然目前没有)的话,用一个状态机去管理每个 request 的 tool call 状态会成一场噩梦,何况有的 requests 会在多轮里面多次调用 tool。因此,为了方便管理每个 request tool call 的状态机和让 tool 被调度的更加均匀。SGLang 采取了 Async Rollout 策略,也即把一个 batch 的 requests 拆为单个 request,然后逐个异步调用 SGLangEngine.generate()
。这样每个 reqeuest 自己管理自己的状态机,方便维护并且 tool call 效率更高。
理解了这一层,我们可以来看看代码实现:
generate_sequences 源码
|
|
这里明确指出,如果是用了 mutli-turn 训练,则将 batch 的 requests 拆为单个 request,调用 _req_level_generate_sequences
;而不调用 tool 的单轮 RL,仍旧组 batch 直接发送。
我们只观察 _req_level_generate_sequences
的部分源码:
_req_level_generate_sequences 部分源码
|
|
现在来看,asyncio.gather(*[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],)
就显得无比清晰了。
数据流管理
我们继续去理解 RayPPOTrainer.fit()
函数,从数据流管理开始。这里我认为最重要的两个类是 DataProto
和 RLHFDataset
。
DataProto
DataProto
是 verl 的数据交换协议,定义在 protocol.py
:
|
|
DataProto
提供标准化的数据交换协议,基于 PyTorch 的 TensorDict,支持张量的批量操作,同时通过 non_tensor_batch
字典来处理 NumPy 数组等非张量数据。meta_info
存储额外的元信息。本身支持的操作挺基础的,典型的比如数据创建、切片、选择、合并、重命名、重复、填充、分块、以及分布式环境下的数据集合与分发。除此之外,DataProto
还通过数据验证 check_consistency()
确保在数据分离和合并过程的一致性。
RLHFDataset
RLHFDataset
是 verl 中用于 RLHF 数据加载的数据集类,继承自 datasets.Dataset
,主要用于处理 Parquet 文件中的数据,包括数据下载、tokenize、过滤、预处理等。
RLHFDataset 源码
|
|
有了 DataProto
和 RLHFDataset
后,我们来观察数据流:
|
|
事实上,只有最初的三步不是 DataProto
,其他都是通过 DataProto
进行数据交换的。具体每步的数据流向如下:
数据流详细分析
A:Parquet
文件
|
|
B:RLHFDataset
|
|
C:DataLoader + collate_fn
|
|
D:DataProto
原始数据
|
|
E:pop
提取生成数据
|
|
F:Rollout
生成
|
|
G:union
合并数据
|
|
H:奖励计算
|
|
I:优势计算
|
|
J:重新计算 log_probs
|
|
K:计算 reference model 的 log_probs
|
|
L:计算 value function
|
|
M1:更新 critic
|
|
M2:更新 actor
|
|
N:返回训练指标
|
|
Rollout
在 part 1 已经讲过了 SGLang 的几个关键函数:
ActorRolloutRefWorker._build_rollout()
SGLangRollout.__init__()
SGLangRollout.AsyncEngine
SGLangRollout._init_inference_engine()
此外,我们还介绍了在“我们究竟在异步什么?“里面介绍了 SGLang 对 multi-turn 场景下的 _req_level_generate_sequences
的特殊实现。我们接着继续分析 SGLang rollout 对 multi-turn 的处理,包括状态机和 tool 调用。
_req_level_generate_sequences
接着上文的讨论,我们继续来看看源代码。
- 如果当前是 tp rank 0,则将一整个 batch 的 prompts 预处理成单个异步请求,并并发执行这些请求以生成序列。rollout 的返回顺序是乱序的,因此需要按照 batch ID 和在 batch 内的 offset 来对返回值重新排序。
- 如果不是 tp rank 0,则将输出请求列表设置为
None
。这里其实也是之前提到过的 mock SPMD 的体现。 - 使用分布式通信,将 tp rank 0 生成的排序后的请求列表广播给所有其他 rank。
- 提取 prompt IDs、response IDs、attention masks、position IDs、loss masks、原始消息和 reward scores。
- 使用 padding token 对 prompt IDs 和 response IDs 进行填充,使其长度一致。
- 将填充后的 prompt 和 response 的 IDs、attention masks 等在最后一个维度上进行拼接,形成完整的序列数据。
- 将处理后的 prompts 和 responses 存储到
TensorDict
对象中,并设置批次大小。 - 将包含批次化张量数据的
TensorDict
和包含原始消息及奖励分数的字典封装到DataProto
对象中并返回。
这里有个比较有趣的地方,注意到 2 中我们强调了,SGLang 并不是严格的 SPMD,但是 3 中,我们仍旧将 tp 0 得到的 response broadcast 给了所有 rank。但是,为了保持 SGLang 外部的训练循环仍旧得到的是一个 SPMD 的返回结果,我们需要让每个 tp randk 都构造并返回相同的 batch,这就需要通过 broadcast 让其他 tp rank 获得 tp 0 的计算结果。这导致了一定的计算冗余,但是相比推理本身的开销,仍旧是可以负担的。
_req_level_generate_sequences 源码
|
|
显然,_req_level_generate_sequences
的核心在于这两个函数:
_preprocess_prompt_to_async_rollout_requests
_async_rollout_a_request
我们分别展开。
_preprocess_prompt_to_async_rollout_requests
- 将 prompts 展开,首先拆开 batch 中的每个 prompt,内层循环为每个 prompt 生成
n
个不同的序列。每个生成的请求都有唯一的batch_data_id
和rollout_offset
标识。 - 当配置了工具时,
_input_ids
和_attention_mask
被设为None
,因为工具调用需要动态构建输入。而没有配置工具的话,使用_pre_process_inputs
函数处理预处理的 token IDs,去除左填充。 - 每个请求对象包含状态管理、工具配置、序列长度限制、tokenizer 配置等元数据,为后续的异步处理提供完整信息。
_preprocess_prompt_to_async_rollout_requests 源码
|
|
这里其实重要的在于整个 AsyncRolloutRequest
,或者说我们用于管理 tool calling 的整个状态机 schema。
schema 状态机
stateDiagram-v2 [*] --> PENDING PENDING --> RUNNING : _handle_pending_state() RUNNING --> TOOL_CALLING : detect_tool_call TOOL_CALLING --> RUNNING : tool_call_executed TOOL_CALLING --> COMPLETED : tool_call_decode_failed RUNNING --> COMPLETED : stop_reason == STOP RUNNING --> [Exit] : finish_reason == LENGTH COMPLETED --> [Exit] note right of TOOL_CALLING if tool_calls == None: raise ValueError end note note right of RUNNING if exceeds max length: finish_reason = LENGTH end note
这些状态机挺抽象的,需要到了和 SGLang rollout 的交互部分才能真的理解到用法,不过我们还是先列举出来。
LENGTH
:达到最大长度限制STOP
:正常停止(如生成 EOS token)TOOL_CALL
:检测到工具调用
role
:消息角色(user/assistant/tool)content
:消息内容tool_calls
:可选的工具调用列表,每个工具调用包含name
和args
字段
目前的实现只支持单个工具的调用,但是魔改玩家太多了,甚至可以做一个 tool manager。
PENDING
:等待处理RUNNING
:正在运行TOOL_CALLING
:正在调用工具COMPLETED
:已完成FAILED
:失败
initialize_request
:验证必需字段(messages、max_prompt_len、tokenizer),使用 tokenizer 的 chat_template 处理消息,初始化所有序列相关字段(input_ids、attention_mask、position_ids、loss_mask),计算生成提示的位置信息_update_input_ids
:以增量方式更新序列信息,自动计算新的 position_ids,维护数据一致性验证get_generation_prompt_ids
:根据配置决定是否使用推理时的 chat_template,动态添加生成提示到输入序列add_assistant_message
:添加助手回复到消息历史,更新输入序列以包含新的回复内容,支持工具调用信息add_tool_response_messages
:添加工具响应到消息历史,更新输入序列但不标记为损失计算部分finalize
:完成请求处理,执行 tokenization 一致性检查,清理生成提示,截断输出序列到合理长度truncate_output_ids
:确保所有序列长度不超过限制,分别处理 input_ids、attention_mask、position_ids、loss_mask
_async_rollout_a_request
文档写的很详尽了,容易 lost in the middle。不过,我们回到主线,先前通过 _preprocess_prompt_to_async_rollout_requests
构造了 AsyncRolloutRequest
后,返回给 _req_level_generate_sequences
,接着进一步通过 _async_rollout_a_request
根据 AsyncRolloutRequest
的状态来 rollout 到底。
- 通过一个
while
循环来处理多轮对话,循环次数上限由self.config.multi_turn.max_turns
控制,或者 requests 返回FinishReasonTypeEnum.STOP
。 - 在循环内部,函数根据
_req
的当前状态 (AsyncRolloutRequestStateEnum
) 执行不同的操作(这块儿逻辑确实很复杂):PENDING
状态:如果请求处于PENDING
状态,则调用self._handle_pending_state(_req)
初始化,然后将状态更新为RUNNING
。TOOL_CALLING
状态:检查最后一条消息的工具调用信息 (_req.messages[-1].tool_calls
)。解析工具调用信息,并通过asyncio.gather
并发地执行每个工具调用。工具的执行逻辑封装在self._tool_map
中,通过工具的名称进行调用。在 tool call 返回后,通过_req.add_tool_response_messages
将工具的响应添加到消息历史中。遍历每个工具调用及其结果,通过_req.update_metrics
更新请求的指标信息。检查当前输入序列长度是否超过模型最大长度限制,如果超过,则设置finish_reason_type
为STOP
并跳出循环。最后,将请求状态更新回RUNNING
,以便进行下一轮的生成。RUNNING
状态:SGLang engine 需要进行 rollout。检查当前 prompt 的长度加上生成一个 token 的长度是否会超过 model context length。调用self._handle_engine_call
来实际调用 SGLang engine;得到输出后,将 finish reason 从字符串转换为FinishReasonTypeEnum
,并递增当前对话轮数current_turns
。如果完成原因是达到最大长度限制 (LENGTH
),则将生成的内容添加到消息历史中,并结束循环。如果没有到达最大长度,则判断 SGLang engine 生成的内容是否包含工具调用,通过self._function_call_parser
来解析生成的内容。如果检测到工具调用,则将finish_reason_type
设置为TOOL_CALL
,并将请求状态更新为TOOL_CALLING
。然后,使用self._function_call_parser.parse_non_stream
解析出工具调用,转换为OpenAIFunctionToolCall
。如果存在有效的工具调用,则通过_req.add_assistant_message
将工具调用信息添加到消息历史中。否则,只添加生成的内容,并将finish_reason_type
设置为STOP
,请求状态设置为COMPLETED
,并结束循环。如果生成的内容不包含工具调用,则直接通过_req.add_assistant_message
将生成的内容添加到消息历史中,并结束循环。
- 如果循环达到
self.config.multi_turn.max_turns
上限,则将finish_reason_type
设置为STOP
。 - 在对话循环结束后,为每个调用的工具计算奖励。遍历
_req.tools_kwargs
中的每个工具,调用工具的calc_reward
方法来计算奖励,以及release
方法来释放工具占用的·资源。计算结果以字典形式存储在tool_reward_scores
中。 - 调用
_req.finalize
方法,完成请求的最终处理,包括执行 tokenization 一致性检查、清理生成提示、截断输出序列到合理长度等。tool_reward_scores
和最终的finish_reason_type
会传递给finalize
方法。最后,函数最终返回处理完成的AsyncRolloutRequest
对象_req
。
_async_rollout_a_request 源码
|
|
pop and union
经过艰难深挖,我们终于完成了 Rollout 的理解,现在回到 RayPPOTrainer.fit()
上。我们来看看 rollout 部分的实现逻辑:
|
|
值得一提的是,我自己写了代码才理解到在 verl 当中,发给 rollout engine 的并不是整个完整的从 dataset 读取的 batch,而是通过 pop 构造的 gen_batch
。pop 是一个就地操作,完成后 batch 里面的 key 当然就没了。为此,如果想让 pop 前后都有一些需要的 key,得留一手考虑。比如说,我希望通过 uid 来把 gen_batch
和 batch
重新 union 起来,得反复添加 uid。
Make Experience
经过了漫长的战线,我们终于分析完了 rollout 部分的逻辑。我们接着分析 make experience 部分的逻辑。
Make Experience 源码
|
|
这一部分的操作还是很好读懂了,非常 standard:
- 通过
self.reward_fn
或self.rm_wg.compute_rm_score
计算 trajectory 的 reward。verl 支持各式各样的 reward,不单单是 reward model。 - 重算 behaviour policy 的 log probabilities: 使用
self.actor_rollout_wg.compute_log_prob(batch)
来重算 log probs。这里原因在 part 1 讲述 importance sampling 的部分也阐述过了。这里非常让我想吐槽的是,verl 里面old_log_prob
就是用 training engine 重算的 behaviour policy 的 log probs,用 old 来描述让我比较费解。 - 计算 reference policy 的 log probabilities: 如果使用了 reference policy,则计算 reference policy 的 log probs,用于 KL divergence 约束。
- 计算 Critic 的 value: 如果使用了 Critic model,则通过
self.critic_wg.compute_values(batch)
预测当前 state 的 value。 - 估算 Advantage: 调用
compute_advantage
函数,根据配置的advantage estimator、折扣因子 (gamma)、GALA 因子 (lam) 等参数,利用 reward 和 value 估计计算优势函数。
Training
非常标准:
|
|