OpenRLHF源码中,如何模型处理以适应?

摘要:本文主要介绍了在 **OpenRLHF**中模型框架设计,主要分为3类模型:1、`actor model`;2、`critic model`;3、`reward model`这三类模型中分别起到作用:1、直接更具prompt输出respon
强化学习框架:OpenRLHF源码解读,模型处理 本文主要介绍 强化学习框架:OpenRLHF源码解读,模型处理 models框架设计 了解一下 OpenRLHF的模型框架设计范式: From:https://arxiv.org/pdf/2405.11143 可以知道一个大概的流程:输入Pormpt通过Actor model输出回复 Response,而后将两部分进行拼接再去由其他模型进行处理 1、actor.py https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/actor.py 这部分主要为加载所需要的模型 class Actor(nn.Module): def __init__(...): if isinstance(pretrain_or_model, str): ... self.model = model_class.from_pretrained( pretrain_or_model, trust_remote_code=True, attn_implementation=attn_implementation, quantization_config=nf4_config, torch_dtype=torch.bfloat16 if bf16 else "auto", device_map=device_map, ) if lora_rank > 0: self.model.enable_input_require_grads() lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout, bias="none", ) self.model = get_peft_model(self.model, lora_config) ... else: self.model = pretrain_or_model @torch.no_grad() def generate(self, input_ids: torch.Tensor, **kwargs): ... sequences = self.model.generate(**generate_args) eos_token_id = generate_args["eos_token_id"] pad_token_id = generate_args["pad_token_id"] return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id) def forward(...): ... output["logits"] = output["logits"].to(torch.float32) # 得到每一个token概率 ... log_probs = log_probs_from_logits( output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature ) ... action_log_probs = log_prob
阅读全文