From 0ec63afec2558d76312bf6fddc3f171ceebfa584 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 13 Sep 2020 15:11:47 -0400 Subject: [PATCH] fix bug in pegasus converter (#7094) --- .../convert_pegasus_tf_to_pytorch.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/convert_pegasus_tf_to_pytorch.py b/src/transformers/convert_pegasus_tf_to_pytorch.py index edf0498f37..c73bb4d66d 100644 --- a/src/transformers/convert_pegasus_tf_to_pytorch.py +++ b/src/transformers/convert_pegasus_tf_to_pytorch.py @@ -47,8 +47,8 @@ PATTERNS = [ def rename_state_dict_key(k): - for pegasus_name, bart_name in PATTERNS: - k = k.replace(pegasus_name, bart_name) + for pegasus_name, hf_name in PATTERNS: + k = k.replace(pegasus_name, hf_name) return k @@ -57,13 +57,12 @@ def rename_state_dict_key(k): # TODO(SS): one constant -def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: +def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: cfg_kwargs = DEFAULTS.copy() cfg_kwargs.update(cfg_updates) - - cfg = PegasusConfig(**cfg_updates) - bart = PegasusForConditionalGeneration(cfg) - sd = bart.model.state_dict() + cfg = PegasusConfig(**cfg_kwargs) + torch_model = PegasusForConditionalGeneration(cfg) + sd = torch_model.model.state_dict() mapping = {} for k, v in tf_weights.items(): new_k = rename_state_dict_key(k) @@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"] empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping} mapping.update(**empty_biases) - missing, extra = bart.model.load_state_dict(mapping, strict=False) + missing, extra = torch_model.model.load_state_dict(mapping, strict=False) unexpected_missing = [ k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"] ] assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" assert extra == [], f"no matches found for the following tf keys {extra}" - return bart + return torch_model def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict: @@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): cfg_updates = task_specific_params[f"summarization_{dataset}"] if dataset == "large": cfg_updates["task_specific_params"] = task_specific_params - torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates) + torch_model = convert_pegasus(tf_weights, cfg_updates) torch_model.save_pretrained(save_dir) sd = torch_model.state_dict() sd.pop("model.decoder.embed_positions.weight")