Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes
* Use get_values everywhere
* Prettier doc
This commit is contained in:
@@ -25,6 +25,7 @@ from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
@@ -89,7 +90,7 @@ class TFModelTesterMixin:
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict = {
|
||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
||||
if isinstance(v, tf.Tensor) and v.ndim > 0
|
||||
@@ -98,21 +99,21 @@ class TFModelTesterMixin:
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in [
|
||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
|
||||
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = tf.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||
@@ -580,7 +581,7 @@ class TFModelTesterMixin:
|
||||
),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
||||
}
|
||||
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||
else:
|
||||
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
||||
@@ -796,9 +797,9 @@ class TFModelTesterMixin:
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
list_lm_models = (
|
||||
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
|
||||
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
|
||||
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
|
||||
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
|
||||
]
|
||||
loss_size = tf.size(added_label)
|
||||
|
||||
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
||||
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||
# if loss is causal lm loss, labels are shifted, so that one label per batch
|
||||
# is cut
|
||||
loss_size = loss_size - self.model_tester.batch_size
|
||||
|
||||
Reference in New Issue
Block a user