give transformers API to BertAbs
This commit is contained in:
committed by
Julien Chaumond
parent
4d18199902
commit
2403a66598
@@ -0,0 +1,161 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Convert BertExtAbs's checkpoints """
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import namedtuple
|
||||||
|
import logging
|
||||||
|
import pdb
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||||
|
from model_bertabs import BertAbsSummarizer
|
||||||
|
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||||
|
|
||||||
|
|
||||||
|
# Immutable record of the hyper-parameters that are handed to both the
# authors' model and the converted model in `convert_bertabs_checkpoints`.
BertAbsConfig = namedtuple(
    "BertAbsConfig",
    [
        # general options
        "temp_dir",
        "large",
        "use_bert_emb",
        "finetune_bert",
        "encoder",
        "share_emb",
        "max_pos",
        # encoder hyper-parameters
        "enc_layers",
        "enc_hidden_size",
        "enc_heads",
        "enc_ff_size",
        "enc_dropout",
        # decoder hyper-parameters
        "dec_layers",
        "dec_hidden_size",
        "dec_heads",
        "dec_ff_size",
        "dec_dropout",
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
    """ Copy/paste and tweak the pre-trained weights provided by the creators
    of BertAbs for the internal architecture.

    The authors' model is instantiated from the checkpoint, its weights are
    copied module by module into a freshly built `BertAbsSummarizer`, both
    models are run on the same padded sample input and, if their outputs
    agree, the state dictionary of the new model is written to disk.

    Arguments:
        path_to_checkpoints: string
            Path to the checkpoint file released by the authors.
        dump_path: string
            Folder in which the converted state dictionary is saved (as
            `bert-ext-abs.pt`). The folder is created when it is missing.

    Raises:
        AssertionError: if the generator weights differ after loading.
        ValueError: if the outputs of the two models are not close
            (`torch.allclose` with `atol=1e-3`).
    """
    import os  # local import: only needed to build the output location

    # Instantiate the authors' model with the pre-trained weights
    config = BertAbsConfig(
        temp_dir=".",
        finetune_bert=False,
        large=False,
        share_emb=True,
        use_bert_emb=False,
        encoder="bert",
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
    )
    # NOTE(review): `torch.load` unpickles the file; only run this on
    # checkpoints obtained from a trusted source.
    # The map_location callable keeps every storage on the CPU.
    checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
    original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
    original.eval()

    new_model = BertAbsSummarizer(config, torch.device("cpu"))
    new_model.eval()

    # -------------------
    # Convert the weights
    # -------------------

    logging.info("convert the model")
    new_model.encoder.load_state_dict(original.bert.state_dict())

    new_model.decoder.generator.load_state_dict(original.generator.state_dict())
    new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict())
    new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict())
    new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict())
    new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict())

    # -----------------------------------
    # Make sure the outputs are identical
    # -----------------------------------

    logging.info("Make sure that the models' outputs are identical")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # prepare the model inputs: both sequences are padded to 512 tokens and
    # given a leading batch dimension of 1
    encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
    encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
    encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
    decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
    decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
    decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)

    # failsafe to make sure the weights reset does not affect the
    # loaded weights.
    assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0

    # forward pass; each value is bound under both the authors' argument
    # names and the names used by the converted model
    src = encoder_input_ids
    tgt = decoder_input_ids
    segs = token_type_ids = None
    clss = None
    mask_src = encoder_attention_mask = None
    mask_tgt = decoder_attention_mask = None
    mask_cls = None

    # The original model does not apply the generator layer immediately but
    # rather in the beam search, so it is applied explicitly here. The
    # converted model's output is passed through a log-softmax before the
    # two results are compared.
    output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
    output_original_model = original.generator(output_original_model)

    output_converted_model = new_model(
        encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
    )[0]
    output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1)

    maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
    print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))

    are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
    if are_identical:
        logging.info("all weights are equal up to 1e-3")
    else:
        raise ValueError("the weights are different. The new model is likely different from the original one.")

    # The model has been saved with torch.save(model) and this is bound to the exact
    # directory structure. We save the state_dict instead.
    logging.info("saving the model's state dictionary")
    # Honour the requested output folder instead of silently writing to the
    # current working directory.
    if not os.path.isdir(dump_path):
        os.makedirs(dump_path)
    torch.save(new_model.state_dict(), os.path.join(dump_path, "bert-ext-abs.pt"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: two mandatory string options that are
    # forwarded, in order, to the conversion routine.
    parser = argparse.ArgumentParser()
    for option, description in (
        ("--bertabs_checkpoint_path", "Path the official PyTorch dump."),
        ("--pytorch_dump_folder_path", "Path to the output PyTorch model."),
    ):
        parser.add_argument(option, default=None, type=str, required=True, help=description)
    args = parser.parse_args()

    convert_bertabs_checkpoints(args.bertabs_checkpoint_path, args.pytorch_dump_folder_path)
|
||||||
141
examples/summarization/configuration_bertabs.py
Normal file
141
examples/summarization/configuration_bertabs.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 The HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" BertAbs configuration """
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
BERTABS_FINETUNED_CONFIG_MAP = {
|
||||||
|
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BertAbsConfig(PretrainedConfig):
    r""" Class to store the configuration of the BertAbs model.

    Arguments:
        vocab_size_or_config_json_file: int or string
            Either the size of the vocabulary, or the path to a JSON file.
            When a path is given every key of the JSON file is set as an
            attribute and the remaining keyword arguments below are ignored.
        temp_dir: string
            Unused in the current situation. Kept for compatibility but will be removed.
        finetune_bert: bool
            Whether to fine-tune the model or not. Will be kept for reference
            in case we want to add the possibility to fine-tune the model.
        large: bool
            Whether to use bert-large as a base.
        share_emb: bool
            Whether the embeddings are shared between the encoder and decoder.
        encoder: string
            Not clear what this does. Leave to "bert" for pre-trained weights.
        max_pos: int
            The maximum sequence length that this model will be used with.
        enc_layers: int
            The number of hidden layers in the Transformer encoder.
        enc_hidden_size: int
            The size of the encoder's layers.
        enc_heads: int
            The number of attention heads for each attention layer in the encoder.
        enc_ff_size: int
            The size of the encoder's feed-forward layers.
        enc_dropout: float
            The dropout probability for all fully connected layers in the
            embeddings, layers, pooler and also the attention probabilities in
            the encoder.
        dec_layers: int
            The number of hidden layers in the decoder.
        dec_hidden_size: int
            The size of the decoder's layers.
        dec_heads: int
            The number of attention heads for each attention layer in the decoder.
        dec_ff_size: int
            The size of the decoder's feed-forward layers.
        dec_dropout: float
            The dropout probability for all fully connected layers in the
            embeddings, layers, pooler and also the attention probabilities in
            the decoder.

    Raises:
        ValueError: if the first argument is neither an int nor a string.
    """

    pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=30522,
        temp_dir=".",
        finetune_bert=False,
        large=False,
        share_emb=True,
        encoder="bert",
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
        **kwargs
    ):
        super(BertAbsConfig, self).__init__(**kwargs)

        if self._input_is_path_to_json(vocab_size_or_config_json_file):
            # The whole configuration comes from the JSON file; its keys are
            # copied verbatim onto the instance.
            path_to_json = vocab_size_or_config_json_file
            with open(path_to_json, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.temp_dir = temp_dir
            self.finetune_bert = finetune_bert
            self.large = large
            self.vocab_size = vocab_size_or_config_json_file
            self.max_pos = max_pos

            self.encoder = encoder
            self.enc_layers = enc_layers
            self.enc_hidden_size = enc_hidden_size
            self.enc_heads = enc_heads
            self.enc_ff_size = enc_ff_size
            self.enc_dropout = enc_dropout

            self.share_emb = share_emb

            self.dec_layers = dec_layers
            self.dec_hidden_size = dec_hidden_size
            self.dec_heads = dec_heads
            self.dec_ff_size = dec_ff_size
            self.dec_dropout = dec_dropout
        else:
            # The two literals are concatenated: the trailing space keeps
            # "(int)" and "or" apart in the final message.
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    def _input_is_path_to_json(self, first_argument):
        """ Checks whether the first argument passed to config
        is the path to a JSON file that contains the config.

        Only the type of the argument is inspected (a text string is
        treated as a path); the file itself is not looked up here.
        """
        is_python_2 = sys.version_info[0] == 2
        if is_python_2:
            return isinstance(first_argument, unicode)  # noqa: F821 (Python 2 only)
        else:
            return isinstance(first_argument, str)
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Convert BertExtAbs's checkpoints
|
||||||
|
|
||||||
|
The file currently does not do much as we ended up copying the exact model
|
||||||
|
structure, but I leave it here in case we ever want to refactor the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import namedtuple
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||||
|
from model_bertabs import BertAbsSummarizer
|
||||||
|
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||||
|
|
||||||
|
|
||||||
|
# Immutable record of the hyper-parameters that are handed to both the
# authors' model and the converted model in `convert_bertabs_checkpoints`.
BertAbsConfig = namedtuple(
    "BertAbsConfig",
    [
        # general options
        "temp_dir",
        "large",
        "use_bert_emb",
        "finetune_bert",
        "encoder",
        "share_emb",
        "max_pos",
        # encoder hyper-parameters
        "enc_layers",
        "enc_hidden_size",
        "enc_heads",
        "enc_ff_size",
        "enc_dropout",
        # decoder hyper-parameters
        "dec_layers",
        "dec_hidden_size",
        "dec_heads",
        "dec_ff_size",
        "dec_dropout",
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
    """ Copy/paste and tweak the pre-trained weights provided by the creators
    of BertAbs for the internal architecture.

    The authors' model is instantiated from the checkpoint, its `bert`,
    `decoder` and `generator` weights are loaded into a freshly built
    `BertAbsSummarizer`, both models are run on the same padded sample input
    and, if their outputs agree, the state dictionary of the new model is
    written to disk.

    Arguments:
        path_to_checkpoints: string
            Path to the checkpoint file released by the authors.
        dump_path: string
            Folder in which the converted state dictionary is saved. The
            folder is created when it is missing.

    Raises:
        AssertionError: if the generator weights differ after loading.
        ValueError: if the outputs of the two models are not close
            (`torch.allclose` with `atol=1e-3`).
    """
    import os  # local import: only needed to build the output location

    # Instantiate the authors' model with the pre-trained weights
    config = BertAbsConfig(
        temp_dir=".",
        finetune_bert=False,
        large=False,
        share_emb=True,
        use_bert_emb=False,
        encoder="bert",
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
    )
    # NOTE(review): `torch.load` unpickles the file; only run this on
    # checkpoints obtained from a trusted source.
    # The map_location callable keeps every storage on the CPU.
    checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
    original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
    original.eval()

    new_model = BertAbsSummarizer(config, torch.device("cpu"))
    new_model.eval()

    # -------------------
    # Convert the weights
    # -------------------

    logging.info("convert the model")
    new_model.bert.load_state_dict(original.bert.state_dict())
    new_model.decoder.load_state_dict(original.decoder.state_dict())
    new_model.generator.load_state_dict(original.generator.state_dict())

    # -----------------------------------
    # Make sure the outputs are identical
    # -----------------------------------

    logging.info("Make sure that the models' outputs are identical")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # prepare the model inputs: both sequences are padded to 512 tokens and
    # given a leading batch dimension of 1
    encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
    encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
    encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
    decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
    decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
    decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)

    # failsafe to make sure the weights reset does not affect the
    # loaded weights.
    assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0

    # forward pass; each value is bound under both the authors' argument
    # names and the names used by the converted model
    src = encoder_input_ids
    tgt = decoder_input_ids
    segs = token_type_ids = None
    clss = None
    mask_src = encoder_attention_mask = None
    mask_tgt = decoder_attention_mask = None
    mask_cls = None

    # The original model does not apply the generator layer immediately but
    # rather in the beam search, so the generator is applied explicitly to
    # the output of each model and both stages are compared.
    output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
    output_original_generator = original.generator(output_original_model)

    output_converted_model = new_model(
        encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
    )[0]
    output_converted_generator = new_model.generator(output_converted_model)

    maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
    print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
    maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
    print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))

    # Only the pre-generator outputs gate the conversion; the generator
    # difference above is printed for information.
    are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
    if are_identical:
        logging.info("all weights are equal up to 1e-3")
    else:
        raise ValueError("the weights are different. The new model is likely different from the original one.")

    # The model has been saved with torch.save(model) and this is bound to the exact
    # directory structure. We save the state_dict instead.
    logging.info("saving the model's state dictionary")
    # Honour the requested output folder instead of silently writing to the
    # current working directory.
    if not os.path.isdir(dump_path):
        os.makedirs(dump_path)
    torch.save(
        new_model.state_dict(),
        os.path.join(dump_path, "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: two mandatory string options that are
    # forwarded, in order, to the conversion routine.
    parser = argparse.ArgumentParser()
    for option, description in (
        ("--bertabs_checkpoint_path", "Path the official PyTorch dump."),
        ("--pytorch_dump_folder_path", "Path to the output PyTorch model."),
    ):
        parser.add_argument(option, default=None, type=str, required=True, help=description)
    args = parser.parse_args()

    convert_bertabs_checkpoints(args.bertabs_checkpoint_path, args.pytorch_dump_folder_path)
|
||||||
1250
examples/summarization/modeling_bertabs.py
Normal file
1250
examples/summarization/modeling_bertabs.py
Normal file
File diff suppressed because it is too large
Load Diff
271
examples/summarization/run_summarization.py
Normal file
271
examples/summarization/run_summarization.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
import argparse
|
||||||
|
from collections import namedtuple
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, SequentialSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
from modeling_bertabs import BertAbs, build_predictor
|
||||||
|
|
||||||
|
from utils_summarization import (
|
||||||
|
SummarizationDataset,
|
||||||
|
encode_for_summarization,
|
||||||
|
build_mask,
|
||||||
|
fit_to_block_size,
|
||||||
|
compute_token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
# Container produced by `collate` for every batch; the field names follow
# the ones used by the original BertAbs code.
Batch = namedtuple(
    "Batch",
    [
        "document_names",
        "batch_size",
        "src",
        "segs",
        "mask_src",
        "tgt_str",
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(args):
    """ Summarize every document found by the data iterator and write the
    summaries to `args.summaries_output_dir`.

    Arguments:
        args: argparse.Namespace
            The command-line options. The function reads `device`,
            `batch_size`, `beam_size`, `min_length`, `max_length`, `alpha`,
            `block_trigram` and `summaries_output_dir`, optionally
            `finetuned_model`, and sets `result_path` and `temp_dir`.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

    # `finetuned_model` is not guaranteed to be defined by the caller; fall
    # back on "cnndm" rather than failing with an AttributeError.
    finetuned_model = getattr(args, "finetuned_model", "cnndm")
    model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-{}".format(finetuned_model))
    bertabs.to(args.device)
    bertabs.eval()

    # Vocabulary ids of the special tokens handed to the predictor.
    symbols = {
        "BOS": tokenizer.vocab["[unused0]"],
        "EOS": tokenizer.vocab["[unused1]"],
        "PAD": tokenizer.vocab["[PAD]"],
    }

    # these (unused) arguments are defined to keep the compatibility
    # with the legacy code and will be deleted in a next iteration.
    args.result_path = ""
    args.temp_dir = ""

    data_iterator = build_data_iterator(args, tokenizer)
    predictor = build_predictor(args, tokenizer, symbols, model)

    logger.info("***** Running evaluation *****")
    logger.info(" Number examples = %d", len(data_iterator.dataset))
    logger.info(" Batch size = %d", args.batch_size)
    logger.info("")
    logger.info("***** Beam Search parameters *****")
    logger.info(" Beam size = %d", args.beam_size)
    logger.info(" Minimum length = %d", args.min_length)
    logger.info(" Maximum length = %d", args.max_length)
    logger.info(" Alpha (length penalty) = %.2f", args.alpha)
    logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))

    # Summaries are written batch after batch.
    for batch in tqdm(data_iterator):
        batch_data = predictor.translate_batch(batch)
        translations = predictor.from_batch(batch_data)
        summaries = [format_summary(t) for t in translations]
        save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||||
|
|
||||||
|
|
||||||
|
def format_summary(translation):
    """ Transforms the output of the `from_batch` function
    into nicely formatted summaries.

    Arguments:
        translation: tuple
            A 3-tuple whose first element is the raw summary string; the
            two other elements are ignored.

    Returns:
        The summary with the `[unused0]`, `[unused1]`, `[unused3]` and
        `[PAD]` markers removed, runs of spaces collapsed to one space, the
        `[unused2]` marker turned into a sentence break, and surrounding
        whitespace stripped.
    """
    import re  # local import: `str.replace` cannot express the " +" pattern

    raw_summary, _, _ = translation

    summary = raw_summary
    for marker in ("[unused0]", "[unused3]", "[PAD]", "[unused1]"):
        summary = summary.replace(marker, "")

    # Removing the markers leaves runs of spaces behind; collapse them so
    # that the " [unused2] " pattern below can match.
    summary = re.sub(r" +", " ", summary)

    summary = summary.replace(" [unused2] ", ". ").replace("[unused2]", "")
    return summary.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def save_summaries(summaries, path, original_document_name):
    """ Write the summaries in files that are prefixed by the original
    files' name with the `_summary` appended.

    A document called `name.ext` yields `name_summary.ext`; a document
    without a dot in its name yields `name_summary`.

    Arguments:
        summaries: List[string]
            The summaries that we produced.
        path: string
            Path where the summaries will be written.
        original_document_name: List[string]
            Names of the documents that were summarized, in the same order
            as `summaries`.
    """
    for summary, document_name in zip(summaries, original_document_name):
        # Prepare the summary file's name: `_summary` goes right before the
        # last extension when there is one.
        if "." in document_name:
            bare_document_name, _, extension = document_name.rpartition(".")
            name = bare_document_name + "_summary." + extension
        else:
            name = document_name + "_summary"

        file_path = os.path.join(path, name)
        # An explicit encoding keeps non-ASCII summaries from depending on
        # the locale's default encoding.
        with open(file_path, "w", encoding="utf-8") as output:
            output.write(summary)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# LOAD the dataset
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
def build_data_iterator(args, tokenizer):
    """ Build the `DataLoader` that feeds the documents to the model.

    Documents are visited sequentially, `args.batch_size` at a time, and
    each batch is formatted by `collate` with a block size of 512.
    """

    def collate_fn(data):
        return collate(data, tokenizer, block_size=512)

    documents = load_and_cache_examples(args, tokenizer)
    return DataLoader(
        documents,
        sampler=SequentialSampler(documents),
        batch_size=args.batch_size,
        collate_fn=collate_fn,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_cache_examples(args, tokenizer):
    """ Return the dataset built from the folder `args.documents_dir`.

    The `tokenizer` argument is not used.
    """
    return SummarizationDataset(args.documents_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def collate(data, tokenizer, block_size):
    """ Collate formats the data passed to the data loader.

    The batch is tokenized here, one batch at a time, so the whole dataset
    never has to be held in memory. The result is returned as a `Batch`
    namedtuple to fit the original BertAbs's API.
    """
    # Drop the entries whose second element (the story) is empty.
    kept = []
    for item in data:
        if len(item[1]) != 0:
            kept.append(item)

    names = []
    for name, _, _ in kept:
        names.append(name)

    encoded_text = []
    for _, story, summary in kept:
        encoded_text.append(encode_for_summarization(story, summary, tokenizer))

    # Only the encoded story is kept; it is fitted to `block_size` using
    # the tokenizer's padding id.
    fitted_stories = []
    for encoded_story, _ in encoded_text:
        fitted_stories.append(fit_to_block_size(encoded_story, block_size, tokenizer.pad_token_id))
    stories = torch.tensor(fitted_stories)

    return Batch(
        document_names=names,
        batch_size=len(stories),
        src=stories,
        segs=compute_token_type_ids(stories, tokenizer.cls_token_id),
        mask_src=build_mask(stories, tokenizer.pad_token_id),
        tgt_str=[""] * len(stories),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_summary(summary_tokens, tokenizer):
    """ Decode the summary and return it in a format
    suitable for evaluation.

    The tokens are moved to the CPU, decoded with the tokenizer, and the
    decoded text is split on "."; a "." is appended to every piece,
    including the piece that follows the last dot.
    """
    token_ids = summary_tokens.to("cpu").numpy()
    decoded = tokenizer.decode(token_ids)
    return [piece + "." for piece in decoded.split(".")]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bool_flag(value):
    """ Convert the text given on the command line into a boolean.

    `type=bool` would turn any non-empty string, "False" included, into
    `True`; this parser reads the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value (true/false), got '{}'.".format(value))


def main():
    """ The main function defines the interface with the users.

    It parses the command line, picks the device, checks the input folder,
    creates the output folder when needed and runs `evaluate`.

    Raises:
        FileNotFoundError: if the documents folder is missing or empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--documents_dir",
        default=None,
        type=str,
        required=True,
        help="The folder where the documents to summarize are located.",
    )
    parser.add_argument(
        "--summaries_output_dir",
        default=None,
        type=str,
        required=True,
        help="The folder in wich the summaries should be written.",
    )
    # `evaluate` reads this option to build the name of the pretrained model.
    parser.add_argument(
        "--finetuned_model",
        default="cnndm",
        type=str,
        help="Suffix of the fine-tuned model to load, as in `bertabs-finetuned-<suffix>`.",
    )
    # EVALUATION options
    parser.add_argument(
        "--visible_gpus",
        default=-1,
        type=int,
        help="Number of GPUs with which to do the training.",
    )
    parser.add_argument(
        "--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
    )
    # BEAM SEARCH arguments
    parser.add_argument(
        "--min_length",
        default=50,
        type=int,
        help="Minimum number of tokens for the summaries.",
    )
    parser.add_argument(
        "--max_length",
        default=200,
        type=int,
        help="Maixmum number of tokens for the summaries.",
    )
    parser.add_argument(
        "--beam_size",
        default=5,
        type=int,
        help="The number of beams to start with for each example.",
    )
    parser.add_argument(
        "--alpha",
        default=0.95,
        type=float,
        help="The value of alpha for the length penalty in the beam search.",
    )
    parser.add_argument(
        "--block_trigram",
        default=True,
        type=_parse_bool_flag,  # so that `--block_trigram False` really disables it
        help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
    )
    args = parser.parse_args()

    # -1 (the default) selects the CPU; any other value selects CUDA.
    args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")

    if not documents_dir_is_valid(args.documents_dir):
        raise FileNotFoundError(
            "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
        )
    maybe_create_output_dir(args.summaries_output_dir)

    evaluate(args)
|
||||||
|
|
||||||
|
|
||||||
|
def documents_dir_is_valid(path):
    """ Tell whether `path` exists and contains at least one entry. """
    if os.path.exists(path):
        # An existing but empty folder is not valid either.
        return len(os.listdir(path)) != 0
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_create_output_dir(path):
    """ Create `path` (and its missing parents) unless it already exists. """
    if os.path.exists(path):
        return
    os.makedirs(path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -10,9 +10,14 @@ from torch.utils.data import Dataset
|
|||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
|
|
||||||
class CNNDailyMailDataset(Dataset):
|
class SummarizationDataset(Dataset):
|
||||||
""" Abstracts the dataset used to train seq2seq models.
|
""" Abstracts the dataset used to train seq2seq models.
|
||||||
|
|
||||||
|
The class will process the documents that are located in the specified
|
||||||
|
folder. The preprocessing will work on any document that is reasonably
|
||||||
|
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||||
|
and the summary.
|
||||||
|
|
||||||
CNN/Daily News:
|
CNN/Daily News:
|
||||||
|
|
||||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||||
@@ -25,32 +30,31 @@ class CNNDailyMailDataset(Dataset):
|
|||||||
[2] https://github.com/abisee/cnn-dailymail/
|
[2] https://github.com/abisee/cnn-dailymail/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir="", prefix="train"):
|
def __init__(self, path="", prefix="train"):
|
||||||
assert os.path.isdir(data_dir)
|
""" We initialize the class by listing all the documents to summarize.
|
||||||
|
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||||
|
"""
|
||||||
|
assert os.path.isdir(path)
|
||||||
|
|
||||||
# We initialize the class by listing all the files that contain
|
self.documents = []
|
||||||
# stories and summaries. Files are not read in memory given
|
story_filenames_list = os.listdir(path)
|
||||||
# the size of the corpus.
|
for story_filename in story_filenames_list:
|
||||||
self.stories_path = []
|
path_to_story = os.path.join(path, story_filename)
|
||||||
datasets = ("cnn", "dailymail")
|
if not os.path.isfile(path_to_story):
|
||||||
for dataset in datasets:
|
continue
|
||||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
self.documents.append(path_to_story)
|
||||||
story_filenames_list = os.listdir(path_to_stories)
|
|
||||||
for story_filename in story_filenames_list:
|
|
||||||
path_to_story = os.path.join(path_to_stories, story_filename)
|
|
||||||
if not os.path.isfile(path_to_story):
|
|
||||||
continue
|
|
||||||
self.stories_path.append(path_to_story)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.stories_path)
|
""" Returns the number of documents. """
|
||||||
|
return len(self.documents)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
story_path = self.stories_path[idx]
|
document_path = self.documents[idx]
|
||||||
with open(story_path, encoding="utf-8") as source:
|
document_name = document_path.split("/")[-1]
|
||||||
|
with open(document_path, encoding="utf-8") as source:
|
||||||
raw_story = source.read()
|
raw_story = source.read()
|
||||||
story_lines, summary_lines = process_story(raw_story)
|
story_lines, summary_lines = process_story(raw_story)
|
||||||
return story_lines, summary_lines
|
return document_name, story_lines, summary_lines
|
||||||
|
|
||||||
|
|
||||||
def process_story(raw_story):
|
def process_story(raw_story):
|
||||||
@@ -80,7 +84,7 @@ def process_story(raw_story):
|
|||||||
story_lines.append(element)
|
story_lines.append(element)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# if "@highlight" is absent from the file we pop
|
# if "@highlight" is absent from the file we pop
|
||||||
# all elements until there is None.
|
# all elements until there is None, raising an exception.
|
||||||
return story_lines, []
|
return story_lines, []
|
||||||
|
|
||||||
# gather summary lines
|
# gather summary lines
|
||||||
@@ -114,14 +118,6 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
|
|||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
|
|
||||||
def build_lm_labels(sequence, pad_token_id):
|
|
||||||
""" Padding token are replaced by the value -1 so they
|
|
||||||
are not taken into account in the loss computation. """
|
|
||||||
padded = sequence.clone()
|
|
||||||
padded[padded == pad_token_id] = -1
|
|
||||||
return padded
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(sequence, pad_token_id):
|
def build_mask(sequence, pad_token_id):
|
||||||
""" Builds the mask. The attention mechanism will only attend to positions
|
""" Builds the mask. The attention mechanism will only attend to positions
|
||||||
with value 1. """
|
with value 1. """
|
||||||
@@ -165,7 +161,7 @@ def compute_token_type_ids(batch, separator_token_id):
|
|||||||
"""
|
"""
|
||||||
batch_embeddings = []
|
batch_embeddings = []
|
||||||
for sequence in batch:
|
for sequence in batch:
|
||||||
sentence_num = 0
|
sentence_num = -1
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for s in sequence:
|
for s in sequence:
|
||||||
if s == separator_token_id:
|
if s == separator_token_id:
|
||||||
@@ -21,7 +21,6 @@ from utils_summarization import (
|
|||||||
compute_token_type_ids,
|
compute_token_type_ids,
|
||||||
fit_to_block_size,
|
fit_to_block_size,
|
||||||
build_mask,
|
build_mask,
|
||||||
build_lm_labels,
|
|
||||||
process_story,
|
process_story,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
expected_summary_lines = ["It was the best of times."]
|
expected_summary_lines = ["It was the best of times."]
|
||||||
self.assertEqual(expected_summary_lines, summary_lines)
|
self.assertEqual(expected_summary_lines, summary_lines)
|
||||||
|
|
||||||
def test_build_lm_labels_no_padding(self):
|
|
||||||
sequence = torch.tensor([1, 2, 3, 4])
|
|
||||||
expected = sequence
|
|
||||||
np.testing.assert_array_equal(
|
|
||||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_build_lm_labels(self):
|
|
||||||
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
|
|
||||||
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
|
|
||||||
np.testing.assert_array_equal(
|
|
||||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_build_mask_no_padding(self):
|
def test_build_mask_no_padding(self):
|
||||||
sequence = torch.tensor([1, 2, 3, 4])
|
sequence = torch.tensor([1, 2, 3, 4])
|
||||||
expected = torch.tensor([1, 1, 1, 1])
|
expected = torch.tensor([1, 1, 1, 1])
|
||||||
@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
||||||
)
|
)
|
||||||
expected = torch.tensor(
|
expected = torch.tensor(
|
||||||
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
|
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = compute_token_type_ids(batch, separator)
|
result = compute_token_type_ids(batch, separator)
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Convert BertExtAbs's checkpoints """
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import namedtuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||||
|
|
||||||
|
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Immutable bundle of the hyper-parameters describing the authors' model.
# The field order is part of the tuple's interface and is kept as is.
BertExtAbsConfig = namedtuple(
    "BertExtAbsConfig",
    [
        "temp_dir",
        "large",
        "finetune_bert",
        "encoder",
        "share_emb",
        "max_pos",
        "enc_layers",
        "enc_hidden_size",
        "enc_heads",
        "enc_ff_size",
        "enc_dropout",
        "dec_layers",
        "dec_hidden_size",
        "dec_heads",
        "dec_ff_size",
        "dec_dropout",
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
    """ Port the pre-trained weights released by the creators of BertExtAbs
    to the internal encoder-decoder architecture.

    Args:
        path_to_checkpoints: path of the authors' PyTorch checkpoint file.
        dump_path: intended location of the converted model.
            NOTE(review): the body does not use it yet; nothing is written
            to disk by this function.
    """

    # Read the checkpoint, leaving every storage where it was deserialized.
    checkpoints = torch.load(
        path_to_checkpoints, map_location=lambda storage, loc: storage
    )

    # Build the authors' model and feed it the pre-trained weights.
    config = BertExtAbsConfig(
        temp_dir=".",
        large=False,
        finetune_bert=False,
        encoder="bert",
        share_emb=True,
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
    )
    bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
    bertextabs.eval()

    # Build our own model; the decoder mirrors the authors' decoder settings.
    decoder_config = BertConfig(
        hidden_size=config.dec_hidden_size,
        num_hidden_layers=config.dec_layers,
        num_attention_heads=config.dec_heads,
        intermediate_size=config.dec_ff_size,
        hidden_dropout_prob=config.dec_dropout,
        attention_probs_dropout_prob=config.dec_dropout,
        is_decoder=True,
    )
    decoder_model = BertForMaskedLM(decoder_config)
    model = Model2Model.from_pretrained(
        "bert-base-uncased", decoder_model=decoder_model
    )
    model.eval()

    # Encoder: the whole state dict is copied in one go.
    model.encoder.load_state_dict(bertextabs.bert.model.state_dict())

    # Decoder: parameters are re-bound one by one.
    source_decoder = bertextabs.decoder
    source_layers = source_decoder.transformer_layers
    target_embeddings = model.decoder.bert.embeddings

    # Embeddings. The positional embeddings are equal to the word embedding
    # plus a modulation that is computed at each forward pass. This may be a
    # source of discrepancy.
    target_embeddings.word_embeddings.weight = source_decoder.embeddings.weight
    target_embeddings.position_embeddings.weight = source_decoder.embeddings.weight
    # Token type embeddings are not defined for the BertExtAbs decoder.
    target_embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        source_decoder.embeddings.weight
    )

    # In the original code the LayerNorms are applied twice in the layers,
    # at the beginning and between the attention layers.
    target_embeddings.LayerNorm.weight = source_layers[0].layer_norm_1.weight

    for i in range(config.dec_layers):
        source = source_layers[i]
        target = model.decoder.bert.encoder.layer[i]

        # Self attention
        target.attention.self.query.weight = source.self_attn.linear_query.weight
        target.attention.self.key.weight = source.self_attn.linear_keys.weight
        target.attention.self.value.weight = source.self_attn.linear_values.weight
        target.attention.output.dense.weight = source.self_attn.final_linear.weight
        target.attention.output.LayerNorm.weight = source.layer_norm_2.weight

        # Attention over the encoder output
        target.crossattention.self.query.weight = source.context_attn.linear_query.weight
        target.crossattention.self.key.weight = source.context_attn.linear_keys.weight
        target.crossattention.self.value.weight = source.context_attn.linear_values.weight
        target.crossattention.output.dense.weight = source.context_attn.final_linear.weight
        target.crossattention.output.LayerNorm.weight = source.feed_forward.layer_norm.weight

        # Intermediate
        target.intermediate.dense.weight = source.feed_forward.w_1.weight

        # Output
        target.output.dense.weight = source.feed_forward.w_2.weight

        # The output LayerNorm comes from the first LayerNorm of the next
        # layer; after the last layer, the decoder's final LayerNorm is used.
        try:
            following_norm = source_layers[i + 1].layer_norm_1
        except IndexError:
            following_norm = source_decoder.layer_norm
        target.output.LayerNorm.weight = following_norm.weight

    # LM Head. The parameters listed below are not assigned by this function:
    #   model.decoder.cls.predictions.transform.dense.weight
    #   model.decoder.cls.predictions.transform.dense.bias
    #   model.decoder.cls.predictions.transform.LayerNorm.weight
    #   model.decoder.cls.predictions.transform.LayerNorm.bias
    #   model.decoder.cls.predictions.decoder.weight
    #   model.decoder.cls.predictions.decoder.bias
    #   model.decoder.cls.predictions.bias.data
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: both options are mandatory string paths.
    parser = argparse.ArgumentParser()
    for option, description in (
        ("--bertextabs_checkpoint_path", "Path the official PyTorch dump."),
        ("--pytorch_dump_folder_path", "Path to the output PyTorch model."),
    ):
        parser.add_argument(
            option,
            default=None,
            type=str,
            required=True,
            help=description,
        )
    args = parser.parse_args()

    convert_bertextabs_checkpoints(
        path_to_checkpoints=args.bertextabs_checkpoint_path,
        dump_path=args.pytorch_dump_folder_path,
    )
|
||||||
@@ -25,7 +25,6 @@ Use Beam Search to generate sequences using encoder-decoder models.
|
|||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -45,6 +44,7 @@ class BeamSearch(object):
|
|||||||
max_length,
|
max_length,
|
||||||
alpha=0,
|
alpha=0,
|
||||||
block_repeating_trigrams=True,
|
block_repeating_trigrams=True,
|
||||||
|
device=torch.device("cpu"),
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
@@ -156,18 +156,24 @@ class BeamSearch(object):
|
|||||||
kwargs_decoder["encoder_hidden_states"] = tile(
|
kwargs_decoder["encoder_hidden_states"] = tile(
|
||||||
encoder_hidden_states, self.beam_size, dim=0
|
encoder_hidden_states, self.beam_size, dim=0
|
||||||
)
|
)
|
||||||
kwargs_decoder["encoder_attention_mask"] = tile(
|
try:
|
||||||
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
kwargs_decoder["encoder_attention_mask"] = tile(
|
||||||
|
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
kwargs_decoder["state"].src = tile(
|
||||||
|
kwargs_decoder["state"].src, self.beam_size, dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# grow the beam iteratively
|
# grow the beam iteratively
|
||||||
batch_size, block_size = encoder_input_ids.size()
|
batch_size, block_size = encoder_input_ids.size()
|
||||||
self._init_beam_state(batch_size)
|
self._init_beam_state(batch_size)
|
||||||
for step in range(self.max_length):
|
for step in range(self.max_length):
|
||||||
|
|
||||||
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
|
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
|
||||||
kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
|
kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
|
||||||
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
|
||||||
|
outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
|
||||||
|
|
||||||
next_token_scores = outputs[0][:, -1, :].squeeze(1)
|
next_token_scores = outputs[0][:, -1, :].squeeze(1)
|
||||||
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
|
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
|
||||||
@@ -178,9 +184,13 @@ class BeamSearch(object):
|
|||||||
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
|
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
|
||||||
"encoder_hidden_states"
|
"encoder_hidden_states"
|
||||||
].index_select(0, surviving_beams_rows)
|
].index_select(0, surviving_beams_rows)
|
||||||
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
|
try:
|
||||||
"encoder_attention_mask"
|
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
|
||||||
].index_select(0, surviving_beams_rows)
|
"encoder_attention_mask"
|
||||||
|
].index_select(0, surviving_beams_rows)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
kwargs_decoder["state"] = state
|
||||||
|
|
||||||
return self.results
|
return self.results
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user