From 343057e1413924152c1a3716a31775660dedb229 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Mon, 1 Feb 2021 21:47:14 +0530
Subject: [PATCH] Fix bart conversion script (#9923)

* fix conversion script

* typo

* import nn
---
 ...ert_bart_original_pytorch_checkpoint_to_pytorch.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
index 8978b8b2e5..baa2fff290 100644
--- a/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
@@ -22,6 +22,7 @@ from pathlib import Path
 import fairseq
 import torch
 from packaging import version
+from torch import nn
 
 from transformers import (
     BartConfig,
@@ -30,7 +31,6 @@ from transformers import (
     BartModel,
     BartTokenizer,
 )
-from transformers.models.bart.modeling_bart import _make_linear_from_emb
 from transformers.utils import logging
 
 
@@ -78,6 +78,13 @@ def load_xsum_checkpoint(checkpoint_path):
     return hub_interface
 
 
+def make_linear_from_emb(emb):
+    vocab_size, emb_size = emb.weight.shape
+    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
+    lin_layer.weight.data = emb.weight.data
+    return lin_layer
+
+
 @torch.no_grad()
 def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
     """
@@ -119,7 +126,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
             model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
             model.model.load_state_dict(state_dict)
             if hasattr(model, "lm_head"):
-                model.lm_head = _make_linear_from_emb(model.model.shared)
+                model.lm_head = make_linear_from_emb(model.model.shared)
             new_model_outputs = model.model(tokens)[0]
 
     # Check results