Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (#6594)
* add conversion script * improve conversion script * make style * add tryout files * fix * update * add causal bert * better names * add tokenizer file as well * finish causal_bert * fix small bugs * improve generate * change naming * renaming * renaming * renaming * remove leftover files * clean files * add fix tokenizer * finalize * correct slow test * update docs * small fixes * fix link * adapt check repo * apply sams and sylvains recommendations * fix import * implement Lysandres recommendations * fix logger warn
This commit is contained in:
committed by
GitHub
parent
d1691d90e5
commit
7fd1febf38
@@ -21,6 +21,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_bert_generation import BertGenerationEncoderTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .test_modeling_gpt2 import GPT2ModelTester
|
||||
from .test_modeling_roberta import RobertaModelTester
|
||||
@@ -31,6 +32,9 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertGenerationDecoder,
|
||||
BertGenerationEncoder,
|
||||
BertLMHeadModel,
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
@@ -489,6 +493,67 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
self.assertEqual(summary, EXPECTED_SUMMARY)
|
||||
|
||||
|
||||
class BertForSeqGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"google/bert_for_seq_generation_L-24_bbc_encoder", "google/bert_for_seq_generation_L-24_bbc_encoder"
|
||||
)
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = BertGenerationEncoder(config)
|
||||
decoder_model = BertGenerationDecoder(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester = BertGenerationEncoderTester(self)
|
||||
encoder_config_and_inputs = model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_input_mask,
|
||||
decoder_token_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_roberta2roberta_summarization(self):
|
||||
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
|
||||
model.to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
|
||||
|
||||
ARTICLE = """The problem is affecting people using the older versions of the PlayStation 3, called the "Fat" model.The problem isn't affecting the newer PS3 Slim systems that have been on sale since September last year.Sony have also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their console for the time being."We hope to resolve this problem within the next 24 hours," a statement reads. "In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3 system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not being able to restore certain data."We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system."The PlayStation Network is used by millions of people around the world.It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores."""
|
||||
|
||||
EXPECTED_SUMMARY = """Sony has said that a bug in its PlayStation 3 console is preventing them from using the machine as a computer."""
|
||||
|
||||
input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids)
|
||||
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(summary, EXPECTED_SUMMARY)
|
||||
|
||||
|
||||
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = RobertaModel(config)
|
||||
|
||||
Reference in New Issue
Block a user