Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink strucuture a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that doesnt exists.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=[2],
eos_token_ids=self.eos_token_ids,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
@@ -234,12 +234,10 @@ class BartHeadTests(unittest.TestCase):
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model(
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
)
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertIsInstance(loss.item(), float)
@@ -292,7 +290,7 @@ class BartHeadTests(unittest.TestCase):
no_repeat_ngram_size=3,
max_length=max_length,
)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
# TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self):