Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink structure a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that don't exist. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -82,7 +82,7 @@ class ModelTester:
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
eos_token_ids=[2],
|
||||
eos_token_ids=self.eos_token_ids,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
@@ -234,12 +234,10 @@ class BartHeadTests(unittest.TestCase):
|
||||
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_model = BartForConditionalGeneration(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model(
|
||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||
)
|
||||
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
@@ -292,7 +290,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
no_repeat_ngram_size=3,
|
||||
max_length=max_length,
|
||||
)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
|
||||
# TODO(SS): uneven length batches, empty inputs
|
||||
|
||||
def test_shift_tokens_right(self):
|
||||
|
||||
Reference in New Issue
Block a user