From d5b0a0e235cc6fccba4f9013cdb54cee01e90a91 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 4 Aug 2020 09:53:51 -0400 Subject: [PATCH] mBART Conversion script (#6230) --- ..._original_pytorch_checkpoint_to_pytorch.py | 13 ------- ...rt_mbart_original_checkpoint_to_pytorch.py | 36 +++++++++++++++++++ 2 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 src/transformers/convert_mbart_original_checkpoint_to_pytorch.py diff --git a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py index bba7b5a76b..52efc88f61 100644 --- a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py @@ -78,19 +78,6 @@ def load_xsum_checkpoint(checkpoint_path): return hub_interface -def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs): - state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] - remove_ignore_keys_(state_dict) - vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] - state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] - mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs) - model = BartForConditionalGeneration(mbart_config) - model.model.load_state_dict(state_dict) - if hasattr(model, "lm_head"): - model.lm_head = _make_linear_from_emb(model.model.shared) - return model - - @torch.no_grad() def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None): """ diff --git a/src/transformers/convert_mbart_original_checkpoint_to_pytorch.py b/src/transformers/convert_mbart_original_checkpoint_to_pytorch.py new file mode 100644 index 0000000000..e61395d0d4 --- /dev/null +++ b/src/transformers/convert_mbart_original_checkpoint_to_pytorch.py @@ -0,0 +1,36 @@ +import argparse + +import torch + +from transformers import BartForConditionalGeneration, MBartConfig + +from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_ + + +def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"): + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + remove_ignore_keys_(state_dict) + vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] + mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size) + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + model = BartForConditionalGeneration(mbart_config) + model.model.load_state_dict(state_dict) + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem." + ) + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--hf_config", + default="facebook/mbart-large-cc25", + type=str, + help="Which huggingface architecture to use: bart-large-xsum", + ) + args = parser.parse_args() + model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config) + model.save_pretrained(args.pytorch_dump_folder_path)