最近 RL sys 圈子的吴锡斌老师在 verl 上设计了将 rollout 与 tool 调用解耦的 AgentLoop,实现了自由灵活的 mutli-turn RL。在每个 AgentLoop 内部,rollout engine 只对外提供一个 token-in-token-out 的接口,而 tool 调用则通过 ToolAgentLoop
来实现。我个人比较喜欢这样解耦的设计,同时,AgentLoop 的代码结构也比较清晰。我个人学习了一次整个代码后,觉着 AgentLoop 的设计甚是不错,但是 ActorRolloutRefWorker
的历史包袱还是很重。
本文简单分析了 agent loop 的源码,并给出了一些自己的看法。
如果我们把整个 ActorRolloutRefWorker
当做一个 sgl.Engine
的话,AgentLoop 里面包装的两层 AsyncSGLangServer
和 AsyncLLMServerManager
。AsyncSGLangServer
相当于在 sgl.Engine
上包装了 fastapi
成了 server,而 AsyncLLMServerManager
是在 server 上包了一层 router 做 load balance,相当于 sglang 的 router。这两层设计都是合理的,主要麻烦的是 ActorRolloutRefWorker
,层层调用,最后一共经过 7 个 class 才调到 sgl.Engine
,最近 verl 团队也在致力于对这块 worker class 的重构,敬请期待。最后,AgentLoopManager
,AgentLoopWorker
和 AgentLoop
这三层,我觉得 AgentLoopWorker
可能未必有必要,其他两层挺合理的。
Related Resources
Script
Related PR
https://github.com/volcengine/verl/pull/2124
Design Docs
https://github.com/volcengine/verl/pull/2563
https://github.com/volcengine/verl/pull/A2598
Commit we are looking at
https://github.com/volcengine/verl/tree/c5b189a1af496d0bc68320cd1d5bd7a1f1e3638a
使用 AgentLoop
安装 verl-sglang 的最新版本:
|
|
具体实现自己的 agent loop(见下文分析),然后配置 config 文件:
|
|
注意,不使用 actor_rollout_ref.rollout.mode=async
的话,会启用 SGLangRollout 本身管理的 mutli-turn 功能,在效果上和 AgentLoop 完全一致。
最后,在数据集构建过程中添加一个新的 agent_name
字段,比如我们在 ~/verl/examples/data_preprocess/gsm8k_multiturn_w_tool.py
中追加 "agent_name": "tool_agent"
:
|
|
调用总览
main_ppo.py -> RayPPOTrainer(fit)-> AgentLoopManager(async) -> AgentLoopWorker -> AsyncLLMServerManager -> AsyncSGLangServer -> AsyncActorRolloutRefWorker -> SGLangRollout -> AsyncEngine -> sgl.Engine
TaskRunner
启动训练,调用RayPPOTrainer.fit()
。RayPPOTrainer
管理训练流程,调用AgentLoopManager.generate_sequences()
开始层层向下调用,同时初始化AsyncActorRolloutRefWorker
。AgentLoopManager
初始化 dp 个AsyncSGLangServer
,随后,初始化num_rollout_workers
个AgentLoopWorker
。- 接着,每个
AgentLoopWorker
根据agent_name
从预先注册好的_agent_loop_registry
初始化自身管理的train_batch_size / num_rollout_workers
个AgentLoop
实例,对于 GRPO,train_batch_size
需要乘以 group size。用户可以依照自身需求注册新的AgentLoop
,目前通过ToolAgentLoop
来完全覆盖了SGLangRollout
中基于_req_level_generate_sequences
实现的 tool call 管理。也就是说, 先前的 multi-turn RL 的 tool 状态管理是在SGLangRollout
内实现的,而AgentLoop
将这层管理抽象了出来,SGLangRollout
只是向上包装为AsyncSGLangServer
来完成 token-in-token-out。 AgentLoop
初始化后,管理 tool 调用的各种状态,并且根据 policy 的返回情况,向下层层调用AsyncLLMServerManager
->AsyncSGLangServer
->AsyncActorRolloutRefWorker
->SGLangRollout
->AsyncEngine
->sgl.Engine
,得到模型输出。 返回输出后,AgentLoop
生命周期结束。AgentLoopWorker
收集所有AgentLoop
的返回值,上交给AgentLoopManager
,等待下一次调用。AgentLoopManager
收集所有AgentLoopWorker
的返回值,返回。
AgentLoopManager
AgentLoop 的最顶层管理者,负责管理 AgentLoopWorker 以及 LLM servers 的生命周期。核心方法是generate_sequences
:向下层层调用,得到 policy model 在给定的 agent loop 环境下的 trajectories。
核心 API
在 RayPPOTrainer
中被初始化:
|
|
具体的初始化非常简洁:
__init__
|
|
- 传入 ActorRolloutRefWOrker 对应的 worker group,在
_initialize_llm_servers
里用来查找对应的 RolloutWorker; - 初始化 llm server 和 agent loop workers;
_initialize_llm_servers
- 计算 dp size:
self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size
- 通过
async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name)
获取服务器类,如Async``SGLang``Server
,作为和下层的sgl.Engine
通信的转接层。 - 用 ray 初始化 dp size 个 server,为每个 dp rank 创建 server 实例。
- 通过
ray.get(server.get_server_address.remote())
获取并记录每个服务器的地址 - 调用
ray.get([server.init_engine.remote() for server in self.async_llm_servers])
;server 从 ray 通过前缀查询,在已经初始化好的 ray actor 中拿到自己对应的所有 SGLang engine。
|
|
_init_agent_loop_workers
在 ray 上初始化 rollout.agent.num_workers
个 AgentLoopWorker
:
|
|
generate_sequences
- 如果配置了
free_cache_engine
,先调用self.wake_up()
chunkes = prompts.chunk(len(self.agent_loop_workers))
将输入批次按 AgentLoopWorker 数量分块。- 每个 agentLoopWorker 处理自身的 chunk,通过
ray.get([worker.generate_sequences.remote(chunk) for ...])
并行执行并得到结果; - 处理完成后调用
self.sleep()
让 server 进入睡眠状态以释放显存 - 计算生成序列和工具调用的性能指标
- 合并所有
A``gentLoopWorker
的输出并返回
Code link [here]
|
|
AsyncSGLangServer
基于 SGLang 的异步服务器实现,继承自AsyncServerBase
。作为 Ray 远程 actor 运行,负责将收到的请求转发给下层的 SGLang Engine。出于 SGLang 的设计,调用 generate
的时候只需要对 master worker(verl 的 inference tp 0)调用即可。
核心 API
init_engine
异步初始化 SGLang 引擎:
- 通过
ray.util.list_named_actors
查找所有匹配的 actors; - 根据命名规则
self.wg_prefix + "WorkerDict_"
解析 actor 名称; - 根据 dp_rank 和 tp_size 分配 actor,确定 master worker(tp rank 0)
|
|
chat_completion
处理 chat_completion
请求:
|
|
- 将请求转发给 master worker 处理
- 返回 JSON 格式的响应
generate
Token in token out 来获得 SGLang Engine 的 inference 结果:
|
|
- 直接调用 master worker 的生成方法
- 支持自定义采样参数
AsyncLLMServerManager
管理多个 OpenAI 兼容的 LLM 服务器 (例如 Async``SGLang``Server
),提供负载均衡和会话粘性功能。支持最少请求负载均衡算法,确保多轮对话发送到同一服务器以实现自动前缀缓存。可以认为就是简单的 router/load balancer 层。
初始化
- 配置服务器句柄列表,随机打乱顺序
- 初始化最少请求负载均衡器:
self.weighted_serveres = [[0, (hash(server), server)] for server in server_handles]
- 创建 LRU 缓存:
self.request_id_to_server = LRUCache(maxsize=max_cache_size)
用于 request_id 到服务器的映射
|
|
_choose_server
|
|
- 会话粘性:相同
request_id
发送给同一server
- 最少请求:新请求分配给当前负载最轻的
server
- 动态更新:使用堆结构维护服务器负载状态
generate
|
|
- 根据
request_id
选择server
- 异步调用 server 的生成接口,token-in-token-out
- 支持性能追踪
AgentLoopWorker
AgentLoopWorker
负责接收数据,向下发给具体的 AgentLoop
。虽然名字是 worker,但是
- 从 ray 的角度来说,
AgentLoopWorker
是有状态的,是 ray actor,而不是 ray worker - 核心函数
generate
是层层套壳,调用其他类;例如single_turn_agent_loop
和tool_agent_loop
来generate
(当然这两个类的generate
也是向下调用,下面会讲到)
__init__
|
|
- 上游传过来的
config
和server_handles
作为参数来初始化AsyncLLMServerManager
,之后会把这个self.server_manager
传给下游; - 根据
config
的config
.
actor_rollout_ref
.
model
.
path
设置model_path, local_path, tokenizer
- 配置
RolloutTraceConfig
用于追踪 trajectories
generate_sequences
|
|
- 利用上游传来的
config
,创建给下游使用的sampling_params
;对 validation batch 要用 validation 参数。 - 利用 batch 的
meta_info
,获得agent_name, raw_prompts, index
。再用这个meta_info
处理获得trajectory_info
;就是利用刚才的 index 来计算在每一个 step 每一个 prompt 被 rollout 的次数,然后存到一个 list 中获得整个 rollout 的 trace; - 利用
agent_names, raw_prompts, trajectory_info
来并发执行_run_agent_loop
。 - 在
_run_agent_loop
函数内,就要进行相应agent_name
的agent_loop
实例化,以及调用agent_loop
对应的 run 函数来 generate。 - 在
_postprocess
中,会根据前面计算出来的 output(被封装成了AgentLoopOutput
格式)来进行后处理;padding,加入 mask,最后封装成一个DataProto
返回。
|
|
AgentLoop
终于进入到了具体的 agent loop 当中,我们观察两种具体的 AgentLoop。
SingleTurnAgentLoop
这个 agent_loop
是默认的单轮对话,处理简单的一问一答,不支持工具调用;最重要的自然是 run
函数:
- 我们传入
agent_loop
的messages
其实是我们从batch
里面获得的raw_prompt
,此处调用apply_chat_template
; - 调用
server_manager
里面的generate
函数来计算response_ids
; - 计算
response_mask
,并根据response_length
截取,封装这些结果成AgentLoopOutput
,padding 在上层AgentLoopManager
的_postprocess
内做;
|
|
ToolAgentLoop
终于到了最核心的地方。ToolAgentLoop
支持多轮对话和工具调用。目前 ToolAgentLoop
可以完全覆盖 SGLangRollout
中基于 _async_rollout_a_request
实现的 tool call 管理。但状态数量和转移关系更加简单。也就是说, 先前的 multi-turn RL 的 tool 状态管理是在 SGLangRollout
内实现的,而 AgentLoop
提前将这层管理抽象了出来。
init_class
下面只介绍一些关键参数的作用:
- **
tool_response_truncate_side
:**控制工具响应内容过长时的截断方式。"left"
:从左侧截断,保留开头部分 + “…(truncated)";"right"
:从右侧截断,保留结尾部分,前面加 “(truncated)…";- 其他值:从中间截断,保留开头和结尾部分,中间加 “…(truncated)…”
tool_config_path
:指定包含工具定义和配置信息的配置文件位置,用于初始化可用的工具列表,比如verl/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml
|
|
tool_list
, tool_schemas
:通过 initialize_tools_from_config(tool_config_path)
函数从配置文件中解析并创建工具实例。
tool_parser
:通过设置类似 actor_rollout_ref.rollout.multi_turn.format=hermes
这样的参数, 可以获取对应的 tool_parser
;比如 HermesToolParser
就是提取 <tool_call></tool_call>
之间的内容,返回对应的 function_call
(function_name
和 function_arguments
), 还有除开 tool_call
内容以外的 content
。
|
|
run
- 和
single_turn_agent_loop
一样,对 promptsapply_chat_template
; - 初始化
user_turns, assistant_turns
,进入 multi-turn 的 loop 循环,直到退出:- 向
server_manager
发送prompt_ids
,得到对应的response_ids
;将本轮返回的response_ids
append 到prompt_ids
中,准备作为下一轮的输入,并且assistant_turns += 1
- 处理边界条件,比如 prompts 过长,没有 tool call 了,或者超出了 max turns;
- 异步执行
_call_tool
:从 response 中 extract 出 Function Call,接着tool
.
execute(instance_id
,
tool_args)
获得相应的tool_response
, 然后截断返回即可。具体的_call_tool
会在后文分析。 tool_responses
随后apply_chat_template
得到tool_response_ids
,同样 append 到prompt_ids
内,然后user_turns += 1
,进入下一轮循环;
- 向
- 退出 tool agent loop 循环后,构造
AgentLoopOutput
注意 num_turns=user_turns+assistant_turns +1,因为 prompt 也算一次 user turn
|
|
call_tool
基于 tool list 内的 tool 来调用工具,例如前面 config 中配置的 calc_gsm8k_reward
,从 tool parser 得到 arguments 就可以代入运算得到相应的tool_response
。如果 tool 调用成功,则会释放 tool 占用的资源,,最后tool_response
根据 tool_response_truncate_side
来做相应的截断。
|
|