|
33 | 33 | ) |
34 | 34 |
|
35 | 35 | from verl.models.registry import ModelRegistry |
| 36 | +from verl.utils.import_utils import is_trl_available |
36 | 37 |
|
37 | 38 |
|
38 | 39 | class LambdaLayer(nn.Module): |
@@ -469,3 +470,70 @@ def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, pos |
469 | 470 |
|
470 | 471 | parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig) |
471 | 472 | return parallel_model |
| 473 | + |
| 474 | + |
def patch_valuehead_model(model) -> None:
    """Patch a trl ``AutoModelForCausalLMWithValueHead`` in place so it exposes
    the ``PreTrainedModel``-style methods downstream code expects.

    The trl value-head wrapper does not implement ``tie_weights``,
    ``get_input_embeddings``, ``get_output_embeddings`` or ``can_generate``;
    this grafts them on, delegating to the wrapped ``pretrained_model`` when it
    is a real ``PreTrainedModel``.

    Args:
        model: a ``trl.AutoModelForCausalLMWithValueHead`` instance; mutated in place.
    """
    from types import MethodType

    from transformers import PreTrainedModel

    from trl import AutoModelForCausalLMWithValueHead

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        # Delegate weight tying to the wrapped base model when available.
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        # Returns None implicitly when the wrapped model is not a PreTrainedModel.
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def can_generate(self) -> bool:
        # A value (critic) model never generates tokens.
        return False

    # Skip the wrapped base model's parameters on save — everything whose
    # qualified name goes through "pretrained_model".
    model._keys_to_ignore_on_save = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    # Direct attribute assignment instead of setattr-with-constant (ruff B010).
    model.tie_weights = MethodType(tie_weights, model)
    model.get_input_embeddings = MethodType(get_input_embeddings, model)
    model.get_output_embeddings = MethodType(get_output_embeddings, model)
    model.can_generate = MethodType(can_generate, model)
    model._no_split_modules = getattr(model.pretrained_model, "_no_split_modules", [])
| 504 | + |
| 505 | + |
def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
    """Load a value (critic) model from ``local_path``.

    First tries ``AutoModelForTokenClassification`` (a checkpoint that is
    natively a value model). On failure, falls back to loading the checkpoint
    as a causal-LM (or vision2seq) model and wrapping it with trl's
    ``AutoModelForCausalLMWithValueHead``, patched via ``patch_valuehead_model``.

    Args:
        local_path: checkpoint directory or model identifier.
        torch_dtype: dtype to load the weights in.
        model_config: ``transformers`` config object for the checkpoint.
        trust_remote_code: forwarded to ``from_pretrained``.

    Returns:
        A model usable as a value head model.

    Raises:
        RuntimeError: the checkpoint is not a token-classification model and
            ``trl`` is not installed, so no fallback is possible.
    """
    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq

    try:
        return AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )
    # Narrowed from BaseException: must not swallow KeyboardInterrupt/SystemExit.
    except Exception as e:
        if not is_trl_available():
            raise RuntimeError(f"model({local_path}) is not a value head model, please install trl to make it valid") from e

    from trl import AutoModelForCausalLMWithValueHead

    # Vision-language configs resolve through the Vision2Seq auto-class;
    # everything else loads as a plain causal LM.
    if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
        module_class = AutoModelForVision2Seq
    else:
        module_class = AutoModelForCausalLM
    ori_model = module_class.from_pretrained(
        pretrained_model_name_or_path=local_path,
        torch_dtype=torch_dtype,
        config=model_config,
        attn_implementation="flash_attention_2",
        trust_remote_code=trust_remote_code,
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)
    patch_valuehead_model(model)
    return model
0 commit comments