Fix T5 and BART for TF (#9063)

* Fix T5 for graph compilation+execution

* Fix BART

* Fix import

* Fix naming

* fix attribute name

* Oops

* fix import

* fix tests

* fix tests

* Update test

* Add missing import

* Address Patrick's comments

* Style

* Address Patrick's comment
This commit is contained in:
Julien Plu
2020-12-14 18:47:00 +01:00
committed by GitHub
parent a9c8bff724
commit df3f4d2aef
8 changed files with 151 additions and 166 deletions

View File

@@ -133,8 +133,6 @@ class TFT5ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_values = outputs
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -142,7 +140,7 @@ class TFT5ModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
@@ -164,7 +162,7 @@ class TFT5ModelTester:
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
_, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -187,7 +185,7 @@ class TFT5ModelTester:
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
@@ -208,8 +206,6 @@ class TFT5ModelTester:
# first forward pass
outputs = model(input_ids, use_cache=True)
output, past_key_values = outputs
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -217,7 +213,7 @@ class TFT5ModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@@ -236,7 +232,7 @@ class TFT5ModelTester:
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
"use_cache": tf.convert_to_tensor([False]),
"use_cache": False,
}
return config, inputs_dict
@@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)
@slow
def test_saved_model_with_attentions_output(self):
pass
@slow
def test_saved_model_with_hidden_states_output(self):
pass
class TFT5EncoderOnlyModelTester:
def __init__(
@@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
def setUp(self):