[VeRL, SGLang] GPU Memory Management Optimization for RL Training and Rollout
The SGLang team's blog: https://hebiao064.github.io/rl-memory-management

Overview

The figure above shows a simplified online RL training flow: the reference and critic models are omitted, and a basic reward function (rather than a reward model) is used to illustrate the pipeline. The optimization effectively targets the training engine and the rollout engine, the two places where the policy model lives.

Starting from the simplified PPO flow:

```python
for prompts, pretrain_batch in dataloader:
    # Stage 1: Rollout generation (inference)
    batch = actor.generate_sequences(prompts)

    # Stage 2: Prepare experience
    batch = reference.compute_log_prob(batch)
    batch = reward.compute_reward(batch)       # Reward function or model
    batch = compute_advantages(batch, algo_type)

    # Stage 3: Actor training
    actor_metrics = actor.update_actor(batch)
```

Each iteration amounts to the actor model doing one rollout followed by one training step. Because veRL co-locates rollout and training, the two sides (which may hold different versions of the actor model) sit on the same GPU group. Resources are shared, but memory management becomes considerably more complicated as a result; a sketch of how the two engines trade GPU ownership within an iteration is given at the end of this section.

Memory issues

Training-phase memory

Under FSDP (fully sharded + full activation checkpointing), the memory each GPU occupies:

Peak memory per GPU: ~48 GB

Inference-phase memory

During inference, the full model is typically loaded (not sharded): ...
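Because the training engine and the rollout engine take turns owning the same GPUs, each iteration has to hand memory back and forth between them. The toy sketch below only illustrates that alternation; every name in it (`ToyEngine`, `to_gpu`, `to_cpu`) is a hypothetical stand-in and not the actual veRL/SGLang API.

```python
# Toy, runnable sketch of the colocated hand-off. All names are hypothetical
# stand-ins; the real veRL/SGLang calls are different.

class ToyEngine:
    """Stands in for the FSDP training engine or the SGLang rollout engine."""

    def __init__(self, name: str):
        self.name = name
        self.on_gpu = False

    def to_gpu(self) -> None:
        # In the real system: reload sharded params/optimizer, or (re)allocate
        # inference weights and the KV cache.
        self.on_gpu = True
        print(f"  {self.name}: state moved to GPU")

    def to_cpu(self) -> None:
        # In the real system: offload params/optimizer to CPU, or free the KV cache.
        self.on_gpu = False
        print(f"  {self.name}: GPU memory released")


trainer = ToyEngine("training engine")
rollout = ToyEngine("rollout engine")

for step in range(2):  # stands in for `for prompts, pretrain_batch in dataloader`
    print(f"iteration {step}")
    trainer.to_cpu()   # free training-side memory before generation
    rollout.to_gpu()   # rollout engine re-materializes weights + KV cache
    print("  generate_sequences(...)  # rollout owns the GPUs")
    rollout.to_cpu()   # drop inference memory
    trainer.to_gpu()   # reload sharded params / optimizer state
    print("  update_actor(...)        # training engine owns the GPUs")
```

The point of the alternation is that at any moment only one of the two engines holds its full GPU working set, which is what the rest of the blog's memory optimizations are built around.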
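To make the training-side accounting above easier to reason about, here is a rough, self-contained estimator of per-GPU peak memory under FSDP full sharding with bf16 mixed precision and Adam. The byte-per-parameter breakdown and the example numbers (7B parameters, 8 GPUs, etc.) are my own illustrative assumptions, not the blog's configuration, so the output will not reproduce the ~48 GB figure exactly.

```python
def estimate_fsdp_train_mem_gib(
    n_params: float,             # total model parameters (assumed example: 7e9)
    world_size: int,             # GPUs the model is sharded across
    largest_unit_params: float,  # params of the biggest FSDP unit, all-gathered at peak
    activation_gib: float,       # activations remaining after full activation checkpointing
) -> float:
    """Rough per-GPU peak memory for FSDP full shard + bf16 mixed precision + Adam.

    Persistent sharded state per parameter (divided by world_size):
      2 B bf16 params + 4 B fp32 master params + 4 B fp32 grads
      + 4 B Adam exp_avg + 4 B Adam exp_avg_sq = 18 B
    Transient peak: one FSDP unit's full bf16 params are all-gathered during
    forward/backward. CUDA context, fragmentation, and communication buffers
    are ignored, so treat the result as a lower bound.
    """
    sharded_bytes = 18 * n_params / world_size
    transient_bytes = 2 * largest_unit_params
    return (sharded_bytes + transient_bytes) / 1024**3 + activation_gib


# Illustrative numbers only (not the blog's setup): a 7B-parameter model on
# 8 GPUs, a ~1B-parameter largest FSDP unit, and ~20 GiB of checkpointed
# activations at the chosen batch size / sequence length.
print(f"{estimate_fsdp_train_mem_gib(7e9, 8, 1e9, 20.0):.1f} GiB per GPU")  # ~36.5 GiB
```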