|
|
|
|
@@ -17,6 +17,7 @@
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import fairseq
|
|
|
|
|
@@ -30,10 +31,11 @@ from transformers import (
|
|
|
|
|
BartModel,
|
|
|
|
|
BartTokenizer,
|
|
|
|
|
)
|
|
|
|
|
from transformers.modeling_bart import _make_linear_from_emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
|
|
|
|
|
|
|
|
|
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
|
|
|
|
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
|
|
|
|
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
|
|
|
|
raise Exception("requires fairseq >= 0.9.0")
|
|
|
|
|
|
|
|
|
|
@@ -57,62 +59,79 @@ def rename_key(dct, old, new):
|
|
|
|
|
dct[new] = val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
|
|
|
|
def load_xsum_checkpoint(checkpoint_path):
|
|
|
|
|
"""Checkpoint path should end in model.pt"""
|
|
|
|
|
sd = torch.load(checkpoint_path, map_location="cpu")
|
|
|
|
|
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
|
|
|
|
hub_interface.model.load_state_dict(sd["model"])
|
|
|
|
|
return hub_interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
|
|
|
|
"""
|
|
|
|
|
Copy/paste/tweak model's weights to our BERT structure.
|
|
|
|
|
"""
|
|
|
|
|
bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
|
|
|
|
|
bart.eval() # disable dropout
|
|
|
|
|
if not os.path.exists(checkpoint_path):
|
|
|
|
|
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
|
|
|
|
else:
|
|
|
|
|
bart = load_xsum_checkpoint(checkpoint_path)
|
|
|
|
|
|
|
|
|
|
bart.model.upgrade_state_dict(bart.model.state_dict())
|
|
|
|
|
hf_model_name = checkpoint_path.replace(".", "-")
|
|
|
|
|
config = BartConfig.from_pretrained(hf_model_name)
|
|
|
|
|
if hf_checkpoint_name is None:
|
|
|
|
|
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
|
|
|
|
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
|
|
|
|
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
|
|
|
|
tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
|
|
|
|
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
|
|
|
|
assert torch.eq(tokens, tokens2).all()
|
|
|
|
|
|
|
|
|
|
if checkpoint_path in ["bart.large", "bart.large.cnn"]:
|
|
|
|
|
state_dict = bart.model.state_dict()
|
|
|
|
|
for k in IGNORE_KEYS:
|
|
|
|
|
state_dict.pop(k, None)
|
|
|
|
|
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
|
|
|
|
model = BartModel(config)
|
|
|
|
|
their_output = bart.extract_features(tokens)
|
|
|
|
|
else: # MNLI Case
|
|
|
|
|
if checkpoint_path == "bart.large.mnli":
|
|
|
|
|
state_dict = bart.state_dict()
|
|
|
|
|
for k in IGNORE_KEYS:
|
|
|
|
|
state_dict.pop(k, None)
|
|
|
|
|
remove_ignore_keys_(state_dict)
|
|
|
|
|
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
|
|
|
|
for src, dest in rename_keys:
|
|
|
|
|
rename_key(state_dict, src, dest)
|
|
|
|
|
model = BartForSequenceClassification(config)
|
|
|
|
|
their_output = bart.predict("mnli", tokens, return_logits=True)
|
|
|
|
|
model = BartForSequenceClassification(config).eval()
|
|
|
|
|
model.load_state_dict(state_dict)
|
|
|
|
|
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
|
|
|
|
new_model_outputs = model(tokens)[0] # logits
|
|
|
|
|
else: # no classification heads to worry about
|
|
|
|
|
state_dict = bart.model.state_dict()
|
|
|
|
|
remove_ignore_keys_(state_dict)
|
|
|
|
|
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
|
|
|
|
fairseq_output = bart.extract_features(tokens)
|
|
|
|
|
if hf_checkpoint_name == "bart-large":
|
|
|
|
|
model = BartModel(config).eval()
|
|
|
|
|
model.load_state_dict(state_dict)
|
|
|
|
|
new_model_outputs = model(tokens).model[0]
|
|
|
|
|
else:
|
|
|
|
|
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
|
|
|
|
model.model.load_state_dict(state_dict)
|
|
|
|
|
if hasattr(model, "lm_head"):
|
|
|
|
|
model.lm_head = _make_linear_from_emb(model.model.shared)
|
|
|
|
|
new_model_outputs = model.model(tokens)[0]
|
|
|
|
|
|
|
|
|
|
# Load state dict
|
|
|
|
|
model.load_state_dict(state_dict)
|
|
|
|
|
model.eval()
|
|
|
|
|
# Check results
|
|
|
|
|
|
|
|
|
|
if checkpoint_path == "bart.large.cnn":
|
|
|
|
|
model = BartForConditionalGeneration(config, base_model=model)
|
|
|
|
|
assert "lm_head.weight" in model.state_dict()
|
|
|
|
|
assert model.lm_head.out_features == config.max_position_embeddings
|
|
|
|
|
model.eval()
|
|
|
|
|
our_outputs = model.model(tokens)[0]
|
|
|
|
|
else:
|
|
|
|
|
our_outputs = model(tokens)[0]
|
|
|
|
|
assert their_output.shape == our_outputs.shape
|
|
|
|
|
assert (their_output == our_outputs).all().item()
|
|
|
|
|
assert fairseq_output.shape == new_model_outputs.shape
|
|
|
|
|
assert (fairseq_output == new_model_outputs).all().item()
|
|
|
|
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
|
|
|
|
model.save_pretrained(pytorch_dump_folder_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_ignore_keys_(state_dict):
|
|
|
|
|
for k in IGNORE_KEYS:
|
|
|
|
|
state_dict.pop(k, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
# Required parameters
|
|
|
|
|
parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
|
|
|
|
|
|
|
|
|
|
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
convert_bart_checkpoint(
|
|
|
|
|
args.fairseq_path, args.pytorch_dump_folder_path,
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
|
|
|
|
)
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
|
|
|
|
|
|