Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
This commit is contained in:
Sylvain Gugger
2021-04-08 18:41:36 -04:00
committed by GitHub
parent 97ccf67bb3
commit ba8b1f4754
26 changed files with 188 additions and 72 deletions

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import AlbertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict