OpenRLHF源码中,如何模型处理以适应?
摘要:本文主要介绍了在 **OpenRLHF**中模型框架设计,主要分为3类模型:1、`actor model`;2、`critic model`;3、`reward model`这三类模型中分别起到作用:1、直接更具prompt输出respon
强化学习框架:OpenRLHF源码解读,模型处理
本文主要介绍 强化学习框架:OpenRLHF源码解读,模型处理
models框架设计
了解一下 OpenRLHF的模型框架设计范式:
From:https://arxiv.org/pdf/2405.11143
可以知道一个大概的流程:输入Pormpt通过Actor model输出回复 Response,而后将两部分进行拼接再去由其他模型进行处理
1、actor.py
https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/actor.py
这部分主要为加载所需要的模型
class Actor(nn.Module):
def __init__(...):
if isinstance(pretrain_or_model, str):
...
self.model = model_class.from_pretrained(
pretrain_or_model,
trust_remote_code=True,
attn_implementation=attn_implementation,
quantization_config=nf4_config,
torch_dtype=torch.bfloat16 if bf16 else "auto",
device_map=device_map,
)
if lora_rank > 0:
self.model.enable_input_require_grads()
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
)
self.model = get_peft_model(self.model, lora_config)
...
else:
self.model = pretrain_or_model
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, **kwargs):
...
sequences = self.model.generate(**generate_args)
eos_token_id = generate_args["eos_token_id"]
pad_token_id = generate_args["pad_token_id"]
return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)
def forward(...):
...
output["logits"] = output["logits"].to(torch.float32) # 得到每一个token概率
...
log_probs = log_probs_from_logits(
output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature
)
...
action_log_probs = log_prob
