Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (#6594)

* add conversion script

* improve conversion script

* make style

* add tryout files

* fix

* update

* add causal bert

* better names

* add tokenizer file as well

* finish causal_bert

* fix small bugs

* improve generate

* change naming

* renaming

* renaming

* renaming

* remove leftover files

* clean files

* add fix tokenizer

* finalize

* correct slow test

* update docs

* small fixes

* fix link

* adapt check repo

* apply sams and sylvains recommendations

* fix import

* implement Lysandres recommendations

* fix logger warn
This commit is contained in:
Patrick von Platen
2020-09-10 16:40:51 +02:00
committed by GitHub
parent d1691d90e5
commit 7fd1febf38
20 changed files with 1508 additions and 9 deletions

View File

@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder
class BertGenerationEncoderTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=50,
initializer_range=0.02,
use_labels=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_labels = use_labels
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if self.use_labels:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = BertGenerationConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, input_mask, token_labels
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
token_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_model(
self,
config,
input_ids,
input_mask,
token_labels,
**kwargs,
):
model = BertGenerationEncoder(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
**kwargs,
):
config.add_cross_attention = True
model = BertGenerationEncoder(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(
self,
config,
input_ids,
input_mask,
token_labels,
*args,
):
model = BertGenerationDecoder(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
token_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()
def setUp(self):
self.model_tester = BertGenerationEncoderTester(self)
self.config_tester = ConfigTester(self, config_class=BertGenerationConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
) = self.model_tester.prepare_config_and_inputs_for_decoder()
input_mask = None
self.model_tester.create_and_check_model_as_decoder(
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
self.assertIsNotNone(model)

View File

@@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bert import BertModelTester
from .test_modeling_bert_generation import BertGenerationEncoderTester
from .test_modeling_common import ids_tensor
from .test_modeling_gpt2 import GPT2ModelTester
from .test_modeling_roberta import RobertaModelTester
@@ -31,6 +32,9 @@ if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
BertGenerationDecoder,
BertGenerationEncoder,
BertLMHeadModel,
BertModel,
BertTokenizer,
@@ -489,6 +493,67 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
self.assertEqual(summary, EXPECTED_SUMMARY)
class BertForSeqGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained(
"google/bert_for_seq_generation_L-24_bbc_encoder", "google/bert_for_seq_generation_L-24_bbc_encoder"
)
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertGenerationEncoder(config)
decoder_model = BertGenerationDecoder(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester = BertGenerationEncoderTester(self)
encoder_config_and_inputs = model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
input_mask,
token_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_input_mask,
decoder_token_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_token_labels": decoder_token_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@slow
def test_roberta2roberta_summarization(self):
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
ARTICLE = """The problem is affecting people using the older versions of the PlayStation 3, called the "Fat" model.The problem isn't affecting the newer PS3 Slim systems that have been on sale since September last year.Sony have also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their console for the time being."We hope to resolve this problem within the next 24 hours," a statement reads. "In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3 system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not being able to restore certain data."We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system."The PlayStation Network is used by millions of people around the world.It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores."""
EXPECTED_SUMMARY = """Sony has said that a bug in its PlayStation 3 console is preventing them from using the machine as a computer."""
input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
output_ids = model.generate(input_ids)
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
self.assertEqual(summary, EXPECTED_SUMMARY)
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = RobertaModel(config)

View File

@@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow
from transformers.tokenization_bert_generation import BertGenerationTokenizer
from .test_tokenization_common import TokenizerTesterMixin
SPIECE_UNDERLINE = ""
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
class BertForSeqGenerationTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BertGenerationTokenizer
def setUp(self):
super().setUp()
tokenizer = BertGenerationTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def test_full_tokenizer(self):
tokenizer = BertGenerationTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[285, 46, 10, 170, 382],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
@cached_property
def big_tokenizer(self):
return BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
@slow
def test_tokenization_base_easy_symbols(self):
symbols = "Hello World!"
original_tokenizer_encodings = [18536, 2260, 101]
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenization_base_hard_symbols(self):
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
original_tokenizer_encodings = [
871,
419,
358,
946,
991,
2521,
452,
358,
1357,
387,
7751,
3536,
112,
985,
456,
126,
865,
938,
5400,
5734,
458,
1368,
467,
786,
2462,
5246,
1159,
633,
865,
4519,
457,
582,
852,
2557,
427,
916,
508,
405,
34324,
497,
391,
408,
11342,
1244,
385,
100,
938,
985,
456,
574,
362,
12597,
3200,
3129,
1172,
]
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
@require_torch
def test_torch_encode_plus_sent_to_model(self):
import torch
from transformers import BertGenerationConfig, BertGenerationEncoder
# Build sequence
first_ten_tokens = list(self.big_tokenizer.get_vocab().keys())[:10]
sequence = " ".join(first_ten_tokens)
encoded_sequence = self.big_tokenizer.encode_plus(sequence, return_tensors="pt", return_token_type_ids=False)
batch_encoded_sequence = self.big_tokenizer.batch_encode_plus(
[sequence + " " + sequence], return_tensors="pt", return_token_type_ids=False
)
config = BertGenerationConfig()
model = BertGenerationEncoder(config)
assert model.get_input_embeddings().weight.shape[0] >= self.big_tokenizer.vocab_size
with torch.no_grad():
model(**encoded_sequence)
model(**batch_encoded_sequence)

View File

@@ -21,11 +21,12 @@ from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
from .test_tokenization_common import TokenizerTesterMixin
SPIECE_UNDERLINE = ""
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"