clean for release
This commit is contained in:
committed by
Julien Chaumond
parent
2a64107e44
commit
f7eba09007
@@ -1,161 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" Convert BertExtAbs's checkpoints """
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from collections import namedtuple
|
|
||||||
import logging
|
|
||||||
import pdb
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
|
||||||
from model_bertabs import BertAbsSummarizer
|
|
||||||
|
|
||||||
from transformers import BertTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
|
||||||
|
|
||||||
|
|
||||||
BertAbsConfig = namedtuple(
|
|
||||||
"BertAbsConfig",
|
|
||||||
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
|
||||||
""" Copy/paste and tweak the pre-trained weights provided by the creators
|
|
||||||
of BertAbs for the internal architecture.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Instantiate the authors' model with the pre-trained weights
|
|
||||||
config = BertAbsConfig(
|
|
||||||
temp_dir=".",
|
|
||||||
finetune_bert=False,
|
|
||||||
large=False,
|
|
||||||
share_emb=True,
|
|
||||||
use_bert_emb=False,
|
|
||||||
encoder="bert",
|
|
||||||
max_pos=512,
|
|
||||||
enc_layers=6,
|
|
||||||
enc_hidden_size=512,
|
|
||||||
enc_heads=8,
|
|
||||||
enc_ff_size=512,
|
|
||||||
enc_dropout=0.2,
|
|
||||||
dec_layers=6,
|
|
||||||
dec_hidden_size=768,
|
|
||||||
dec_heads=8,
|
|
||||||
dec_ff_size=2048,
|
|
||||||
dec_dropout=0.2,
|
|
||||||
)
|
|
||||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
|
||||||
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
|
||||||
original.eval()
|
|
||||||
|
|
||||||
new_model = BertAbsSummarizer(config, torch.device("cpu"))
|
|
||||||
new_model.eval()
|
|
||||||
|
|
||||||
# -------------------
|
|
||||||
# Convert the weights
|
|
||||||
# -------------------
|
|
||||||
|
|
||||||
logging.info("convert the model")
|
|
||||||
new_model.encoder.load_state_dict(original.bert.state_dict())
|
|
||||||
|
|
||||||
new_model.decoder.generator.load_state_dict(original.generator.state_dict())
|
|
||||||
new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict())
|
|
||||||
new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict())
|
|
||||||
new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict())
|
|
||||||
new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict())
|
|
||||||
|
|
||||||
# ----------------------------------
|
|
||||||
# Make sure the outpus are identical
|
|
||||||
# ----------------------------------
|
|
||||||
|
|
||||||
logging.info("Make sure that the models' outputs are identical")
|
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
||||||
|
|
||||||
# prepare the model inputs
|
|
||||||
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
|
|
||||||
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
|
|
||||||
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
|
|
||||||
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
|
|
||||||
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
|
|
||||||
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
|
|
||||||
|
|
||||||
# failsafe to make sure the weights reset does not affect the
|
|
||||||
# loaded weights.
|
|
||||||
assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0
|
|
||||||
|
|
||||||
# forward pass
|
|
||||||
src = encoder_input_ids
|
|
||||||
tgt = decoder_input_ids
|
|
||||||
segs = token_type_ids = None
|
|
||||||
clss = None
|
|
||||||
mask_src = encoder_attention_mask = None
|
|
||||||
mask_tgt = decoder_attention_mask = None
|
|
||||||
mask_cls = None
|
|
||||||
|
|
||||||
# The original model does not apply the geneator layer immediatly but rather in
|
|
||||||
# the beam search (where it combines softmax + linear layer). Since we already
|
|
||||||
# apply the softmax in our generation process we only apply the linear layer here.
|
|
||||||
# We make sure that the outputs of the full stack are identical
|
|
||||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
|
||||||
output_original_model = original.generator(output_original_model)
|
|
||||||
|
|
||||||
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
|
|
||||||
output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1)
|
|
||||||
|
|
||||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
|
||||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
|
||||||
|
|
||||||
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
|
|
||||||
if are_identical:
|
|
||||||
logging.info("all weights are equal up to 1e-3")
|
|
||||||
else:
|
|
||||||
raise ValueError("the weights are different. The new model is likely different from the original one.")
|
|
||||||
|
|
||||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
|
||||||
# directory structure. We save the state_dict instead.
|
|
||||||
logging.info("saving the model's state dictionary")
|
|
||||||
torch.save(new_model.state_dict(), "bert-ext-abs.pt")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--bertabs_checkpoint_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path the official PyTorch dump.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pytorch_dump_folder_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the output PyTorch model.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
convert_bertabs_checkpoints(
|
|
||||||
args.bertabs_checkpoint_path,
|
|
||||||
args.pytorch_dump_folder_path,
|
|
||||||
)
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# MIT License
|
# MIT License
|
||||||
|
|
||||||
# Copyright (c) 2019 Yang Liu
|
# Copyright (c) 2019 Yang Liu and the HuggingFace team
|
||||||
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
|||||||
9
examples/summarization/requirements.txt
Normal file
9
examples/summarization/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# progress bars in model download and training scripts
|
||||||
|
tqdm
|
||||||
|
# Accessing files from S3 directly.
|
||||||
|
boto3
|
||||||
|
# Used for downloading models over HTTP
|
||||||
|
requests
|
||||||
|
# For ROUGE
|
||||||
|
nltk
|
||||||
|
py-rouge
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#! /usr/bin/python3
|
||||||
import argparse
|
import argparse
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import logging
|
import logging
|
||||||
@@ -97,6 +98,32 @@ def evaluate(args):
|
|||||||
print(str_scores)
|
print(str_scores)
|
||||||
|
|
||||||
|
|
||||||
|
def save_summaries(summaries, path, original_document_name):
|
||||||
|
""" Write the summaries in fies that are prefixed by the original
|
||||||
|
files' name with the `_summary` appended.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
original_document_names: List[string]
|
||||||
|
Name of the document that was summarized.
|
||||||
|
path: string
|
||||||
|
Path were the summaries will be written
|
||||||
|
summaries: List[string]
|
||||||
|
The summaries that we produced.
|
||||||
|
"""
|
||||||
|
for summary, document_name in zip(summaries, original_document_name):
|
||||||
|
# Prepare the summary file's name
|
||||||
|
if "." in document_name:
|
||||||
|
bare_document_name = ".".join(document_name.split(".")[:-1])
|
||||||
|
extension = document_name.split(".")[-1]
|
||||||
|
name = bare_document_name + "_summary." + extension
|
||||||
|
else:
|
||||||
|
name = document_name + "_summary"
|
||||||
|
|
||||||
|
file_path = os.path.join(path, name)
|
||||||
|
with open(file_path, "w") as output:
|
||||||
|
output.write(summary)
|
||||||
|
|
||||||
|
|
||||||
def format_summary(translation):
|
def format_summary(translation):
|
||||||
""" Transforms the output of the `from_batch` function
|
""" Transforms the output of the `from_batch` function
|
||||||
into nicely formatted summaries.
|
into nicely formatted summaries.
|
||||||
@@ -151,32 +178,6 @@ def save_rouge_scores(str_scores):
|
|||||||
output.write(str_scores)
|
output.write(str_scores)
|
||||||
|
|
||||||
|
|
||||||
def save_summaries(summaries, path, original_document_name):
|
|
||||||
""" Write the summaries in fies that are prefixed by the original
|
|
||||||
files' name with the `_summary` appended.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
original_document_names: List[string]
|
|
||||||
Name of the document that was summarized.
|
|
||||||
path: string
|
|
||||||
Path were the summaries will be written
|
|
||||||
summaries: List[string]
|
|
||||||
The summaries that we produced.
|
|
||||||
"""
|
|
||||||
for summary, document_name in zip(summaries, original_document_name):
|
|
||||||
# Prepare the summary file's name
|
|
||||||
if "." in document_name:
|
|
||||||
bare_document_name = ".".join(document_name.split(".")[:-1])
|
|
||||||
extension = document_name.split(".")[-1]
|
|
||||||
name = bare_document_name + "_summary." + extension
|
|
||||||
else:
|
|
||||||
name = document_name + "_summary"
|
|
||||||
|
|
||||||
file_path = os.path.join(path, name)
|
|
||||||
with open(file_path, "w") as output:
|
|
||||||
output.write(summary)
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# LOAD the dataset
|
# LOAD the dataset
|
||||||
#
|
#
|
||||||
@@ -323,7 +324,7 @@ def main():
|
|||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||||
)
|
)
|
||||||
maybe_create_output_dir(args.summaries_output_dir)
|
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||||
|
|
||||||
evaluate(args)
|
evaluate(args)
|
||||||
|
|
||||||
@@ -339,10 +340,5 @@ def documents_dir_is_valid(path):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_output_dir(path):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -10,6 +10,3 @@ regex
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
# For XLM
|
# For XLM
|
||||||
sacremoses
|
sacremoses
|
||||||
# For ROUGE
|
|
||||||
nltk
|
|
||||||
py-rouge
|
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" Convert BertExtAbs's checkpoints """
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from collections import namedtuple
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
|
||||||
|
|
||||||
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
BertExtAbsConfig = namedtuple(
|
|
||||||
"BertExtAbsConfig",
|
|
||||||
["temp_dir", "large", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
|
|
||||||
""" Copy/paste and tweak the pre-trained weights provided by the creators
|
|
||||||
of BertExtAbs for the internal architecture.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load checkpoints in memory
|
|
||||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
|
||||||
|
|
||||||
# Instantiate the authors' model with the pre-trained weights
|
|
||||||
config = BertExtAbsConfig(
|
|
||||||
temp_dir=".",
|
|
||||||
finetune_bert=False,
|
|
||||||
large=False,
|
|
||||||
share_emb=True,
|
|
||||||
encoder="bert",
|
|
||||||
max_pos=512,
|
|
||||||
enc_layers=6,
|
|
||||||
enc_hidden_size=512,
|
|
||||||
enc_heads=8,
|
|
||||||
enc_ff_size=512,
|
|
||||||
enc_dropout=0.2,
|
|
||||||
dec_layers=6,
|
|
||||||
dec_hidden_size=768,
|
|
||||||
dec_heads=8,
|
|
||||||
dec_ff_size=2048,
|
|
||||||
dec_dropout=0.2,
|
|
||||||
)
|
|
||||||
bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
|
||||||
bertextabs.eval()
|
|
||||||
|
|
||||||
# Instantiate our version of the model
|
|
||||||
decoder_config = BertConfig(
|
|
||||||
hidden_size=config.dec_hidden_size,
|
|
||||||
num_hidden_layers=config.dec_layers,
|
|
||||||
num_attention_heads=config.dec_heads,
|
|
||||||
intermediate_size=config.dec_ff_size,
|
|
||||||
hidden_dropout_prob=config.dec_dropout,
|
|
||||||
attention_probs_dropout_prob=config.dec_dropout,
|
|
||||||
is_decoder=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_model = BertForMaskedLM(decoder_config)
|
|
||||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder_model)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Let us now start the weight copying process
|
|
||||||
model.encoder.load_state_dict(bertextabs.bert.model.state_dict())
|
|
||||||
|
|
||||||
# Decoder
|
|
||||||
|
|
||||||
# Embeddings. The positional embeddings are equal to the word embedding plus a modulation
|
|
||||||
# that is computed at each forward pass. This may be a source of discrepancy.
|
|
||||||
model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight
|
|
||||||
model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight
|
|
||||||
model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight) # not defined for BertExtAbs decoder
|
|
||||||
|
|
||||||
# In the original code the LayerNorms are applied twice in the layers, at the beginning and between the
|
|
||||||
# attention layers.
|
|
||||||
model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight
|
|
||||||
|
|
||||||
for i in range(config.dec_layers):
|
|
||||||
|
|
||||||
# self attention
|
|
||||||
model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight
|
|
||||||
|
|
||||||
# attention
|
|
||||||
model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight
|
|
||||||
model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight
|
|
||||||
|
|
||||||
# intermediate
|
|
||||||
model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight
|
|
||||||
|
|
||||||
# output
|
|
||||||
model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight
|
|
||||||
|
|
||||||
try:
|
|
||||||
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight
|
|
||||||
except IndexError:
|
|
||||||
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight
|
|
||||||
|
|
||||||
# LM Head
|
|
||||||
"""
|
|
||||||
model.decoder.cls.predictions.transform.dense.weight
|
|
||||||
model.decoder.cls.predictions.transform.dense.biais
|
|
||||||
model.decoder.cls.predictions.transform.LayerNorm.weight
|
|
||||||
model.decoder.cls.predictions.transform.LayerNorm.biais
|
|
||||||
model.decoder.cls.predictions.decoder.weight
|
|
||||||
model.decoder.cls.predictions.decoder.biais
|
|
||||||
model.decoder.cls.predictions.biais.data
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--bertextabs_checkpoint_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path the official PyTorch dump.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pytorch_dump_folder_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the output PyTorch model.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
convert_bertextabs_checkpoints(
|
|
||||||
args.bertextabs_checkpoint_path,
|
|
||||||
args.pytorch_dump_folder_path,
|
|
||||||
)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .beam_search import BeamSearch
|
|
||||||
@@ -117,7 +117,8 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
kwargs_common = {
|
kwargs_common = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
if not argument.startswith("encoder_")
|
||||||
|
and not argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
kwargs_decoder = kwargs_common.copy()
|
kwargs_decoder = kwargs_common.copy()
|
||||||
kwargs_encoder = kwargs_common.copy()
|
kwargs_encoder = kwargs_common.copy()
|
||||||
@@ -157,27 +158,14 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def save_pretrained(self, save_directory, model_type="bert"):
|
def save_pretrained(self, save_directory):
|
||||||
""" Save an EncoderDecoder model and its configuration file in a format such
|
""" Save a Seq2Seq model and its configuration file in a format such
|
||||||
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
|
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
|
||||||
|
|
||||||
We save the encoder' and decoder's parameters in two separate directories.
|
We save the encoder' and decoder's parameters in two separate directories.
|
||||||
|
|
||||||
If we want the weight loader to function we need to preprend the model
|
|
||||||
type to the directories' names. As far as I know there is no simple way
|
|
||||||
to infer the type of the model (except maybe by parsing the class'
|
|
||||||
names, which is not very future-proof). For now, we ask the user to
|
|
||||||
specify the model type explicitly when saving the weights.
|
|
||||||
"""
|
"""
|
||||||
encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
|
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||||
if not os.path.exists(encoder_path):
|
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||||
os.makedirs(encoder_path)
|
|
||||||
self.encoder.save_pretrained(encoder_path)
|
|
||||||
|
|
||||||
decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
|
|
||||||
if not os.path.exists(decoder_path):
|
|
||||||
os.makedirs(decoder_path)
|
|
||||||
self.decoder.save_pretrained(decoder_path)
|
|
||||||
|
|
||||||
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||||
""" The forward pass on a seq2eq depends what we are performing:
|
""" The forward pass on a seq2eq depends what we are performing:
|
||||||
@@ -205,7 +193,8 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
kwargs_common = {
|
kwargs_common = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
if not argument.startswith("encoder_")
|
||||||
|
and not argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
kwargs_decoder = kwargs_common.copy()
|
kwargs_decoder = kwargs_common.copy()
|
||||||
kwargs_encoder = kwargs_common.copy()
|
kwargs_encoder = kwargs_common.copy()
|
||||||
@@ -228,7 +217,9 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
encoder_hidden_states = encoder_outputs[
|
||||||
|
0
|
||||||
|
] # output the last layer hidden state
|
||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user