Fix bart conversion script (#9923)
* fix conversion script * typo * import nn
This commit is contained in:
@@ -22,6 +22,7 @@ from pathlib import Path
|
|||||||
import fairseq
|
import fairseq
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
BartConfig,
|
BartConfig,
|
||||||
@@ -30,7 +31,6 @@ from transformers import (
|
|||||||
BartModel,
|
BartModel,
|
||||||
BartTokenizer,
|
BartTokenizer,
|
||||||
)
|
)
|
||||||
from transformers.models.bart.modeling_bart import _make_linear_from_emb
|
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -78,6 +78,13 @@ def load_xsum_checkpoint(checkpoint_path):
|
|||||||
return hub_interface
|
return hub_interface
|
||||||
|
|
||||||
|
|
||||||
|
def make_linear_from_emb(emb):
|
||||||
|
vocab_size, emb_size = emb.weight.shape
|
||||||
|
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||||
|
lin_layer.weight.data = emb.weight.data
|
||||||
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||||
"""
|
"""
|
||||||
@@ -119,7 +126,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
|||||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||||
model.model.load_state_dict(state_dict)
|
model.model.load_state_dict(state_dict)
|
||||||
if hasattr(model, "lm_head"):
|
if hasattr(model, "lm_head"):
|
||||||
model.lm_head = _make_linear_from_emb(model.model.shared)
|
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||||
new_model_outputs = model.model(tokens)[0]
|
new_model_outputs = model.model(tokens)[0]
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
|
|||||||
Reference in New Issue
Block a user