Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
This commit is contained in:
Sylvain Gugger
2021-04-08 18:41:36 -04:00
committed by GitHub
parent 97ccf67bb3
commit ba8b1f4754
26 changed files with 188 additions and 72 deletions

View File

@@ -25,6 +25,7 @@ from importlib import import_module
from typing import List, Tuple
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
@@ -89,7 +90,7 @@ class TFModelTesterMixin:
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
inputs_dict = copy.deepcopy(inputs_dict)
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
if isinstance(v, tf.Tensor) and v.ndim > 0
@@ -98,21 +99,21 @@ class TFModelTesterMixin:
}
if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]:
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
@@ -580,7 +581,7 @@ class TFModelTesterMixin:
),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else:
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
@@ -796,9 +797,9 @@ class TFModelTesterMixin:
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
list_lm_models = (
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
)
for model_class in self.all_model_classes:
@@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
]
loss_size = tf.size(added_label)
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
# if loss is causal lm loss, labels are shift, so that one label per batch
# is cut
loss_size = loss_size - self.model_tester.batch_size