Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink structure a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that don't exist.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -24,14 +24,15 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import T5Config, T5Model, T5WithLMHeadModel
from transformers import T5Config, T5Model, T5ForConditionalGeneration
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
@require_torch
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
@@ -56,6 +57,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_ids=[1],
pad_token_id=0,
scope=None,
):
self.parent = parent
@@ -75,20 +78,22 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.scope = scope
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
def prepare_config_and_inputs(self):
encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
encoder_attention_mask = None
attention_mask = None
decoder_attention_mask = None
if self.use_attention_mask:
encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
decoder_lm_labels = None
lm_labels = None
if self.use_labels:
decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = T5Config(
vocab_size=self.vocab_size,
@@ -101,41 +106,36 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_ids=self.eos_token_ids,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
)
return (
config,
encoder_input_ids,
input_ids,
decoder_input_ids,
encoder_attention_mask,
attention_mask,
decoder_attention_mask,
decoder_lm_labels,
lm_labels,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_t5_model(
self,
config,
encoder_input_ids,
decoder_input_ids,
encoder_attention_mask,
decoder_attention_mask,
decoder_lm_labels,
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids,
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
encoder_attention_mask=encoder_attention_mask,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids
)
decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
result = {
"encoder_output": encoder_output,
@@ -149,22 +149,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
)
def create_and_check_t5_with_lm_head(
self,
config,
encoder_input_ids,
decoder_input_ids,
encoder_attention_mask,
decoder_attention_mask,
decoder_lm_labels,
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5WithLMHeadModel(config=config)
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
outputs = model(
encoder_input_ids=encoder_input_ids,
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_lm_labels=decoder_lm_labels,
lm_labels=lm_labels,
)
loss, prediction_scores, encoder_features = outputs
self.parent.assertEqual(len(outputs), 3)
@@ -181,17 +175,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
encoder_input_ids,
input_ids,
decoder_input_ids,
encoder_attention_mask,
attention_mask,
decoder_attention_mask,
decoder_lm_labels,
lm_labels,
) = config_and_inputs
inputs_dict = {
"encoder_input_ids": encoder_input_ids,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
}
return config, inputs_dict