Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AlbertConfig, is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
Reference in New Issue
Block a user