mBART Conversion script (#6230)
This commit is contained in:
@@ -78,19 +78,6 @@ def load_xsum_checkpoint(checkpoint_path):
|
|||||||
return hub_interface
|
return hub_interface
|
||||||
|
|
||||||
|
|
||||||
def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
|
|
||||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
|
||||||
remove_ignore_keys_(state_dict)
|
|
||||||
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
|
||||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
|
||||||
mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
|
|
||||||
model = BartForConditionalGeneration(mbart_config)
|
|
||||||
model.model.load_state_dict(state_dict)
|
|
||||||
if hasattr(model, "lm_head"):
|
|
||||||
model.lm_head = _make_linear_from_emb(model.model.shared)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import BartForConditionalGeneration, MBartConfig
|
||||||
|
|
||||||
|
from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||||
|
|
||||||
|
|
||||||
|
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
||||||
|
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||||
|
remove_ignore_keys_(state_dict)
|
||||||
|
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
||||||
|
mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
|
||||||
|
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||||
|
model = BartForConditionalGeneration(mbart_config)
|
||||||
|
model.model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument(
|
||||||
|
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||||
|
)
|
||||||
|
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_config",
|
||||||
|
default="facebook/mbart-large-cc25",
|
||||||
|
type=str,
|
||||||
|
help="Which huggingface architecture to use: bart-large-xsum",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config)
|
||||||
|
model.save_pretrained(args.pytorch_dump_folder_path)
|
||||||
Reference in New Issue
Block a user