fix bug in pegasus converter (#7094)
This commit is contained in:
@@ -47,8 +47,8 @@ PATTERNS = [
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
|
||||
for pegasus_name, bart_name in PATTERNS:
|
||||
k = k.replace(pegasus_name, bart_name)
|
||||
for pegasus_name, hf_name in PATTERNS:
|
||||
k = k.replace(pegasus_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
@@ -57,13 +57,12 @@ def rename_state_dict_key(k):
|
||||
# TODO(SS): one constant
|
||||
|
||||
|
||||
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
||||
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
||||
cfg_kwargs = DEFAULTS.copy()
|
||||
cfg_kwargs.update(cfg_updates)
|
||||
|
||||
cfg = PegasusConfig(**cfg_updates)
|
||||
bart = PegasusForConditionalGeneration(cfg)
|
||||
sd = bart.model.state_dict()
|
||||
cfg = PegasusConfig(**cfg_kwargs)
|
||||
torch_model = PegasusForConditionalGeneration(cfg)
|
||||
sd = torch_model.model.state_dict()
|
||||
mapping = {}
|
||||
for k, v in tf_weights.items():
|
||||
new_k = rename_state_dict_key(k)
|
||||
@@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo
|
||||
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
|
||||
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
|
||||
mapping.update(**empty_biases)
|
||||
missing, extra = bart.model.load_state_dict(mapping, strict=False)
|
||||
missing, extra = torch_model.model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return bart
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
|
||||
@@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
|
||||
cfg_updates = task_specific_params[f"summarization_{dataset}"]
|
||||
if dataset == "large":
|
||||
cfg_updates["task_specific_params"] = task_specific_params
|
||||
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
|
||||
torch_model = convert_pegasus(tf_weights, cfg_updates)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
sd = torch_model.state_dict()
|
||||
sd.pop("model.decoder.embed_positions.weight")
|
||||
|
||||
Reference in New Issue
Block a user