Output hidden states (#4978)
* Configure all models to use output_hidden_states as argument passed to forward() * Pass all tests * Remove cast_bool_to_primitive in TF Flaubert model * correct tf xlnet * add pytorch test * add tf test * Fix broken tests * Configure all models to use output_hidden_states as argument passed to forward() * Pass all tests * Remove cast_bool_to_primitive in TF Flaubert model * correct tf xlnet * add pytorch test * add tf test * Fix broken tests * Refactor output_hidden_states for mobilebert * Reset and remerge to master Co-authored-by: Joseph Liu <joseph.liu@coinflex.com> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -143,14 +143,13 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also works using config
|
||||
@@ -162,7 +161,6 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
if chunk_length is not None:
|
||||
@@ -201,14 +199,13 @@ class ModelTesterMixin:
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
|
||||
self_attentions = outputs[-1]
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
@@ -493,19 +490,16 @@ class ModelTesterMixin:
|
||||
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_hidden_states = True
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = outputs[-1]
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||
seq_length = self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
|
||||
@@ -517,6 +511,18 @@ class ModelTesterMixin:
|
||||
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also works using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
|
||||
@@ -392,17 +392,23 @@ class TFModelTesterMixin:
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_hidden_states = True
|
||||
def check_hidden_states_output(config, inputs_dict, model_class):
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user