Fix T5 and BART for TF (#9063)
* Fix T5 for graph compilation+execution * Fix BART * Fix import * Fix naming * fix attribute name * Oops * fix import * fix tests * fix tests * Update test * Add missing import * Address Patrick's comments * Style * Address Patrick's comment
This commit is contained in:
@@ -133,8 +133,6 @@ class TFT5ModelTester:
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_values = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
@@ -142,7 +140,7 @@ class TFT5ModelTester:
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||
@@ -164,7 +162,7 @@ class TFT5ModelTester:
|
||||
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
|
||||
|
||||
# first forward pass
|
||||
_, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -187,7 +185,7 @@ class TFT5ModelTester:
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
|
||||
@@ -208,8 +206,6 @@ class TFT5ModelTester:
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -217,7 +213,7 @@ class TFT5ModelTester:
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||
|
||||
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
||||
|
||||
@@ -236,7 +232,7 @@ class TFT5ModelTester:
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_attention_mask": input_mask,
|
||||
"use_cache": tf.convert_to_tensor([False]),
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFT5Model.from_pretrained("t5-small")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
@@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user