clean for release

This commit is contained in:
Rémi Louf
2019-12-06 22:01:48 +01:00
committed by Julien Chaumond
parent 2a64107e44
commit f7eba09007
8 changed files with 49 additions and 376 deletions

View File

@@ -1,158 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BertExtAbsConfig = namedtuple(
"BertExtAbsConfig",
["temp_dir", "large", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertExtAbs for the internal architecture.
"""
# Load checkpoints in memory
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
# Instantiate the authors' model with the pre-trained weights
config = BertExtAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
bertextabs.eval()
# Instantiate our version of the model
decoder_config = BertConfig(
hidden_size=config.dec_hidden_size,
num_hidden_layers=config.dec_layers,
num_attention_heads=config.dec_heads,
intermediate_size=config.dec_ff_size,
hidden_dropout_prob=config.dec_dropout,
attention_probs_dropout_prob=config.dec_dropout,
is_decoder=True,
)
decoder_model = BertForMaskedLM(decoder_config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder_model)
model.eval()
# Let us now start the weight copying process
model.encoder.load_state_dict(bertextabs.bert.model.state_dict())
# Decoder
# Embeddings. The positional embeddings are equal to the word embedding plus a modulation
# that is computed at each forward pass. This may be a source of discrepancy.
model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight) # not defined for BertExtAbs decoder
# In the original code the LayerNorms are applied twice in the layers, at the beginning and between the
# attention layers.
model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight
for i in range(config.dec_layers):
# self attention
model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight
# attention
model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight
# intermediate
model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight
# output
model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight
try:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight
except IndexError:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight
# LM Head
"""
model.decoder.cls.predictions.transform.dense.weight
model.decoder.cls.predictions.transform.dense.biais
model.decoder.cls.predictions.transform.LayerNorm.weight
model.decoder.cls.predictions.transform.LayerNorm.biais
model.decoder.cls.predictions.decoder.weight
model.decoder.cls.predictions.decoder.biais
model.decoder.cls.predictions.biais.data
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertextabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertextabs_checkpoints(
args.bertextabs_checkpoint_path,
args.pytorch_dump_folder_path,
)

View File

@@ -1 +0,0 @@
from .beam_search import BeamSearch

View File

@@ -117,7 +117,8 @@ class PreTrainedEncoderDecoder(nn.Module):
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
@@ -157,27 +158,14 @@ class PreTrainedEncoderDecoder(nn.Module):
return model
def save_pretrained(self, save_directory, model_type="bert"):
""" Save an EncoderDecoder model and its configuration file in a format such
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format such
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
We save the encoder' and decoder's parameters in two separate directories.
If we want the weight loader to function we need to preprend the model
type to the directories' names. As far as I know there is no simple way
to infer the type of the model (except maybe by parsing the class'
names, which is not very future-proof). For now, we ask the user to
specify the model type explicitly when saving the weights.
"""
encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
if not os.path.exists(encoder_path):
os.makedirs(encoder_path)
self.encoder.save_pretrained(encoder_path)
decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
if not os.path.exists(decoder_path):
os.makedirs(decoder_path)
self.decoder.save_pretrained(decoder_path)
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
""" The forward pass on a seq2eq depends what we are performing:
@@ -205,7 +193,8 @@ class PreTrainedEncoderDecoder(nn.Module):
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
@@ -228,7 +217,9 @@ class PreTrainedEncoderDecoder(nn.Module):
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
encoder_hidden_states = encoder_outputs[
0
] # output the last layer hidden state
else:
encoder_outputs = ()