Fix T5 and BART for TF (#9063)
* Fix T5 for graph compilation+execution * Fix BART * Fix import * Fix naming * fix attribute name * Oops * fix import * fix tests * fix tests * Update test * Add missing import * Address Patrick's comments * Style * Address Patrick's comment
This commit is contained in:
@@ -171,6 +171,11 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
# Since PR #8777 was merged, a saved model is always executed in graph mode.
|
||||
# In graph mode the boolean options are always read from the config, so we copy
|
||||
# use_cache from the inputs (when present) into the config to keep the two consistent.
|
||||
if "use_cache" in class_inputs_dict:
|
||||
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
@@ -207,6 +212,11 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
# Since PR #8777 was merged, a saved model is always executed in graph mode.
|
||||
# In graph mode the boolean options are always read from the config, so we copy
|
||||
# use_cache from the inputs (when present) into the config to keep the two consistent.
|
||||
if "use_cache" in class_inputs_dict:
|
||||
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
@@ -249,10 +259,11 @@ class TFModelTesterMixin:
|
||||
if "T5" in main_layer_class.__name__:
|
||||
# Use the same values as in TFT5ModelTester for this shared layer
|
||||
shared = TFSharedEmbeddings(99, 32, name="shared")
|
||||
config.use_cache = False
|
||||
config.use_cache = inputs_dict.pop("use_cache", None)
|
||||
main_layer = main_layer_class(config, embed_tokens=shared)
|
||||
else:
|
||||
main_layer = main_layer_class(config)
|
||||
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
}
|
||||
@@ -321,10 +332,13 @@ class TFModelTesterMixin:
|
||||
|
||||
# Check that predictions on the first output (logits/hidden-states) are close enough, given low-level computational differences
|
||||
pt_model.eval()
|
||||
pt_inputs_dict = dict(
|
||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
||||
)
|
||||
pt_inputs_dict = {}
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||
if type(key) == bool:
|
||||
pt_inputs_dict[name] = key
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
@@ -358,10 +372,13 @@ class TFModelTesterMixin:
|
||||
|
||||
# Check that predictions on the first output (logits/hidden-states) are close enough, given low-level computational differences
|
||||
pt_model.eval()
|
||||
pt_inputs_dict = dict(
|
||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
||||
)
|
||||
pt_inputs_dict = {}
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||
if type(key) == bool:
|
||||
key = np.array(key, dtype=bool)
|
||||
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
@@ -574,13 +591,29 @@ class TFModelTesterMixin:
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
hidden_states = outputs[-1]
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_hidden_states = outputs.encoder_hidden_states
|
||||
decoder_hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(encoder_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(encoder_hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
self.assertEqual(len(decoder_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.hidden_states
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
@@ -796,7 +829,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user