Add has_attentions to TFModelTesterMixin as done on PyTorch side (#16259)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -114,6 +114,7 @@ class TFModelTesterMixin:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
has_attentions = True
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -539,9 +540,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
# Output all for aggressive testing
|
# Output all for aggressive testing
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
# Pure convolutional models have no attention
|
if self.has_attentions:
|
||||||
# TODO: use a better and general criteria
|
|
||||||
if "TFConvNext" not in model_class.__name__:
|
|
||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
|
|
||||||
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
||||||
@@ -567,8 +566,6 @@ class TFModelTesterMixin:
|
|||||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
tf_model = model_class(config)
|
tf_model = model_class(config)
|
||||||
pt_model = pt_model_class(config)
|
pt_model = pt_model_class(config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user