fix bug in pegasus converter (#7094)
This commit is contained in:
@@ -47,8 +47,8 @@ PATTERNS = [
|
|||||||
|
|
||||||
def rename_state_dict_key(k):
|
def rename_state_dict_key(k):
|
||||||
|
|
||||||
for pegasus_name, bart_name in PATTERNS:
|
for pegasus_name, hf_name in PATTERNS:
|
||||||
k = k.replace(pegasus_name, bart_name)
|
k = k.replace(pegasus_name, hf_name)
|
||||||
return k
|
return k
|
||||||
|
|
||||||
|
|
||||||
@@ -57,13 +57,12 @@ def rename_state_dict_key(k):
|
|||||||
# TODO(SS): one constant
|
# TODO(SS): one constant
|
||||||
|
|
||||||
|
|
||||||
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
||||||
cfg_kwargs = DEFAULTS.copy()
|
cfg_kwargs = DEFAULTS.copy()
|
||||||
cfg_kwargs.update(cfg_updates)
|
cfg_kwargs.update(cfg_updates)
|
||||||
|
cfg = PegasusConfig(**cfg_kwargs)
|
||||||
cfg = PegasusConfig(**cfg_updates)
|
torch_model = PegasusForConditionalGeneration(cfg)
|
||||||
bart = PegasusForConditionalGeneration(cfg)
|
sd = torch_model.model.state_dict()
|
||||||
sd = bart.model.state_dict()
|
|
||||||
mapping = {}
|
mapping = {}
|
||||||
for k, v in tf_weights.items():
|
for k, v in tf_weights.items():
|
||||||
new_k = rename_state_dict_key(k)
|
new_k = rename_state_dict_key(k)
|
||||||
@@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo
|
|||||||
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
|
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
|
||||||
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
|
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
|
||||||
mapping.update(**empty_biases)
|
mapping.update(**empty_biases)
|
||||||
missing, extra = bart.model.load_state_dict(mapping, strict=False)
|
missing, extra = torch_model.model.load_state_dict(mapping, strict=False)
|
||||||
unexpected_missing = [
|
unexpected_missing = [
|
||||||
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
|
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
|
||||||
]
|
]
|
||||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||||
return bart
|
return torch_model
|
||||||
|
|
||||||
|
|
||||||
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
|
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
|
||||||
@@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
|
|||||||
cfg_updates = task_specific_params[f"summarization_{dataset}"]
|
cfg_updates = task_specific_params[f"summarization_{dataset}"]
|
||||||
if dataset == "large":
|
if dataset == "large":
|
||||||
cfg_updates["task_specific_params"] = task_specific_params
|
cfg_updates["task_specific_params"] = task_specific_params
|
||||||
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
|
torch_model = convert_pegasus(tf_weights, cfg_updates)
|
||||||
torch_model.save_pretrained(save_dir)
|
torch_model.save_pretrained(save_dir)
|
||||||
sd = torch_model.state_dict()
|
sd = torch_model.state_dict()
|
||||||
sd.pop("model.decoder.embed_positions.weight")
|
sd.pop("model.decoder.embed_positions.weight")
|
||||||
|
|||||||
Reference in New Issue
Block a user