LLM | VeRL相关文档都包含哪些内容?
摘要:目录 PPO 示例脚本的 readme GRPO 示例脚本的 readme PPO 示例脚本的 readme examplesppo_trainerREADME.md 近端策略优化(Proximal Policy Optimizatio
目录
PPO 示例脚本的 readme
GRPO 示例脚本的 readme
PPO 示例脚本的 readme
examples/ppo_trainer/README.md
近端策略优化(Proximal Policy Optimization,PPO)是一类用于强化学习的策略梯度方法,由 OpenAI 于 2017 年提出。PPO 在简单性、稳定性和性能之间取得了平衡,使其成为现代 RL 应用(包括大规模语言模型微调)中最广泛使用的算法之一。
像 REINFORCE 或 Vanilla Policy Gradient 这样的传统策略梯度方法存在以下问题:
高方差和样本效率低下。
因策略更新过大导致的不稳定性。
PPO 使用一种裁剪后的替代目标函数来解决这个问题,该函数避免了过大的更新,同时不需要二阶导数。
关于 PPO 的更多技术细节,我们建议阅读 OpenAI spinning up 教程 的介绍以及论文 Proximal Policy Optimization Algorithms。
1 关键组件
Actor-Critic 架构:PPO 需要 actor 模型(策略)和 critic 模型(价值函数)。这与 GRPO 和 RLOO 等其他不需要 critic 模型的算法不同。
广义优势估计 (GAE):PPO 使用 GAE 来计算优势值,这有助于在保持低偏差的同时减少策略梯度估计的方差。
裁剪后的替代目标:PPO 的核心是通过裁剪后的替代目标函数实现的,该函数限制了策略更新。
2 配置
请注意,所有包含 micro_batch_size 的配置都用于配置每次前向或后向传递的最大样本数或 token 数,以避免 GPU 内存不足(OOM),其值不应改变算法/收敛行为。
大多数 critic 配置与 actor 的配置类似。注意下图中省略了 critic 模型。
data.train_batch_size:用于生成一组采样轨迹/rollout 的提示的全局批次大小。响应/轨迹的数量是 data.train_batch_size * actor_rollout.ref.rollout.n。
actor_rollout_ref.actor.ppo_mini_batch_size:采样得到的轨迹集被分割成多个大小为 ppo_mini_batch_size 的小批量,用于 PPO actor 的更新。ppo_mini_batch_size 是所有工作节点上的全局大小。
critic.ppo_mini_batch_size:采样得到的轨迹集被分割成多个大小为 ppo_mini_batch_size 的小批量,用于 PPO critic 的更新。ppo_mini_batch_size 是所有工作节点上的全局大小。
actor_rollout_ref.actor.clip_ratio:PPO 的裁剪范围。默认为 0.2。
actor_rollout_ref.actor.ppo_epochs:在一组采样轨迹上对 actor 进行 PPO 更新的轮数。
critic.ppo_epochs:在一组采样轨迹上对 critic 进行 PPO 更新的轮数。默认为 actor_rollout_ref.actor.ppo_epochs。
algorithm.gamma:折扣因子。
algorithm.lam:在 GAE 估计器中用于权衡偏差和方差的 lambda 项。
algorithm.adv_estimator:支持 gae、grpo、reinforce_plus_plus、reinforce_plus_plus_baseline、rloo、rloo_vectorized(似乎 PPO 算法应该使用 gae)。
3 高级扩展
3.1 KL 散度控制
防止策略与参考策略偏离太远的选项。提供了两种机制:KL 奖励惩罚和 KL 损失。更多技术细节,请参阅 Training language models to follow instructions with human feedback。
使用 KL 损失进行 KL 散度控制的选项:
actor_rollout_ref.actor.use_kl_loss:是否在 actor 中使用 KL 损失。使用时,我们不会在奖励函数中应用 KL。默认为 False。
actor_rollout_ref.actor.kl_loss_coef:KL 损失的系数。默认为 0.001。
actor_rollout_ref.actor.kl_loss_type:支持 kl(k1)、abs、mse(k2)、low_var_kl(k3) 和 full。
