Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (#6594)

* add conversion script

* improve conversion script

* make style

* add tryout files

* fix

* update

* add causal bert

* better names

* add tokenizer file as well

* finish causal_bert

* fix small bugs

* improve generate

* change naming

* renaming

* renaming

* renaming

* remove leftover files

* clean files

* add fix tokenizer

* finalize

* correct slow test

* update docs

* small fixes

* fix link

* adapt check repo

* apply sams and sylvains recommendations

* fix import

* implement Lysandres recommendations

* fix logger warn
This commit is contained in:
Patrick von Platen
2020-09-10 16:40:51 +02:00
committed by GitHub
parent d1691d90e5
commit 7fd1febf38
20 changed files with 1508 additions and 9 deletions

View File

@@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bert import BertModelTester
from .test_modeling_bert_generation import BertGenerationEncoderTester
from .test_modeling_common import ids_tensor
from .test_modeling_gpt2 import GPT2ModelTester
from .test_modeling_roberta import RobertaModelTester
@@ -31,6 +32,9 @@ if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
BertGenerationDecoder,
BertGenerationEncoder,
BertLMHeadModel,
BertModel,
BertTokenizer,
@@ -489,6 +493,67 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
self.assertEqual(summary, EXPECTED_SUMMARY)
class BertForSeqGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained(
"google/bert_for_seq_generation_L-24_bbc_encoder", "google/bert_for_seq_generation_L-24_bbc_encoder"
)
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertGenerationEncoder(config)
decoder_model = BertGenerationDecoder(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester = BertGenerationEncoderTester(self)
encoder_config_and_inputs = model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
input_mask,
token_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_input_mask,
decoder_token_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_token_labels": decoder_token_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@slow
def test_roberta2roberta_summarization(self):
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
ARTICLE = """The problem is affecting people using the older versions of the PlayStation 3, called the "Fat" model.The problem isn't affecting the newer PS3 Slim systems that have been on sale since September last year.Sony have also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their console for the time being."We hope to resolve this problem within the next 24 hours," a statement reads. "In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3 system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not being able to restore certain data."We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system."The PlayStation Network is used by millions of people around the world.It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores."""
EXPECTED_SUMMARY = """Sony has said that a bug in its PlayStation 3 console is preventing them from using the machine as a computer."""
input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
output_ids = model.generate(input_ids)
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
self.assertEqual(summary, EXPECTED_SUMMARY)
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = RobertaModel(config)