Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
# get abs dir
|
# get abs dir
|
||||||
save_directory = os.path.abspath(save_directory)
|
save_directory = os.path.abspath(save_directory)
|
||||||
# save config as well
|
# save config as well
|
||||||
|
self.config.architectures = [self.__class__.__name__[4:]]
|
||||||
self.config.save_pretrained(save_directory)
|
self.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
|
|||||||
@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
logger.info(f"Saved model created in {saved_model_dir}")
|
logger.info(f"Saved model created in {saved_model_dir}")
|
||||||
|
|
||||||
# Save configuration file
|
# Save configuration file
|
||||||
|
self.config.architectures = [self.__class__.__name__[2:]]
|
||||||
self.config.save_pretrained(save_directory)
|
self.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
|
"auto_factory": ["get_values"],
|
||||||
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
|
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
|
||||||
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
|
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
|
||||||
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
|
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
|
||||||
@@ -104,6 +105,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .auto_factory import get_values
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
||||||
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||||
|
|||||||
@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_class(config, model_mapping):
|
||||||
|
supported_models = model_mapping[type(config)]
|
||||||
|
if not isinstance(supported_models, (list, tuple)):
|
||||||
|
return supported_models
|
||||||
|
|
||||||
|
name_to_model = {model.__name__: model for model in supported_models}
|
||||||
|
architectures = getattr(config, "architectures", [])
|
||||||
|
for arch in architectures:
|
||||||
|
if arch in name_to_model:
|
||||||
|
return name_to_model[arch]
|
||||||
|
elif f"TF{arch}" in name_to_model:
|
||||||
|
return name_to_model[f"TF{arch}"]
|
||||||
|
elif f"Flax{arch}" in name_to_model:
|
||||||
|
return name_to_model[f"Flax{arch}"]
|
||||||
|
|
||||||
|
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
|
||||||
|
# defaults.
|
||||||
|
return supported_models[0]
|
||||||
|
|
||||||
|
|
||||||
class _BaseAutoModelClass:
|
class _BaseAutoModelClass:
|
||||||
# Base class for auto models.
|
# Base class for auto models.
|
||||||
_model_mapping = None
|
_model_mapping = None
|
||||||
@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
|
|||||||
|
|
||||||
def from_config(cls, config, **kwargs):
|
def from_config(cls, config, **kwargs):
|
||||||
if type(config) in cls._model_mapping.keys():
|
if type(config) in cls._model_mapping.keys():
|
||||||
return cls._model_mapping[type(config)](config, **kwargs)
|
model_class = _get_model_class(config, cls._model_mapping)
|
||||||
|
return model_class(config, **kwargs)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if type(config) in cls._model_mapping.keys():
|
if type(config) in cls._model_mapping.keys():
|
||||||
return cls._model_mapping[type(config)].from_pretrained(
|
model_class = _get_model_class(config, cls._model_mapping)
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
|
|||||||
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
||||||
new_class.from_pretrained = classmethod(from_pretrained)
|
new_class.from_pretrained = classmethod(from_pretrained)
|
||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
def get_values(model_mapping):
|
||||||
|
result = []
|
||||||
|
for model in model_mapping.values():
|
||||||
|
if isinstance(model, (list, tuple)):
|
||||||
|
result += list(model)
|
||||||
|
else:
|
||||||
|
result.append(model)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_class_name(model_class):
|
||||||
|
if isinstance(model_class, (list, tuple)):
|
||||||
|
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
|
||||||
|
return f":class:`~transformers.{model_class.__name__}`"
|
||||||
|
|
||||||
|
|
||||||
def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
||||||
if config_to_class is None and not use_model_types:
|
if config_to_class is None and not use_model_types:
|
||||||
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
|
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
|
||||||
if use_model_types:
|
if use_model_types:
|
||||||
if config_to_class is None:
|
if config_to_class is None:
|
||||||
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
|
model_type_to_name = {
|
||||||
|
model_type: f":class:`~transformers.{config.__name__}`"
|
||||||
|
for model_type, config in CONFIG_MAPPING.items()
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
model_type_to_name = {
|
model_type_to_name = {
|
||||||
model_type: config_to_class[config].__name__
|
model_type: _get_class_name(config_to_class[config])
|
||||||
for model_type, config in CONFIG_MAPPING.items()
|
for model_type, config in CONFIG_MAPPING.items()
|
||||||
if config in config_to_class
|
if config in config_to_class
|
||||||
}
|
}
|
||||||
lines = [
|
lines = [
|
||||||
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||||
for model_type in sorted(model_type_to_name.keys())
|
for model_type in sorted(model_type_to_name.keys())
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
|
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
|
||||||
config_to_model_name = {
|
config_to_model_name = {
|
||||||
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
||||||
}
|
}
|
||||||
lines = [
|
lines = [
|
||||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
|
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
|
||||||
for config_name in sorted(config_to_name.keys())
|
for config_name in sorted(config_to_name.keys())
|
||||||
]
|
]
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
|
|||||||
)
|
)
|
||||||
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
|
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
|
||||||
from ..funnel.modeling_funnel import (
|
from ..funnel.modeling_funnel import (
|
||||||
|
FunnelBaseModel,
|
||||||
FunnelForMaskedLM,
|
FunnelForMaskedLM,
|
||||||
FunnelForMultipleChoice,
|
FunnelForMultipleChoice,
|
||||||
FunnelForPreTraining,
|
FunnelForPreTraining,
|
||||||
@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
|
|||||||
(CTRLConfig, CTRLModel),
|
(CTRLConfig, CTRLModel),
|
||||||
(ElectraConfig, ElectraModel),
|
(ElectraConfig, ElectraModel),
|
||||||
(ReformerConfig, ReformerModel),
|
(ReformerConfig, ReformerModel),
|
||||||
(FunnelConfig, FunnelModel),
|
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
|
||||||
(LxmertConfig, LxmertModel),
|
(LxmertConfig, LxmertModel),
|
||||||
(BertGenerationConfig, BertGenerationEncoder),
|
(BertGenerationConfig, BertGenerationEncoder),
|
||||||
(DebertaConfig, DebertaModel),
|
(DebertaConfig, DebertaModel),
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
|
|||||||
TFFlaubertWithLMHeadModel,
|
TFFlaubertWithLMHeadModel,
|
||||||
)
|
)
|
||||||
from ..funnel.modeling_tf_funnel import (
|
from ..funnel.modeling_tf_funnel import (
|
||||||
|
TFFunnelBaseModel,
|
||||||
TFFunnelForMaskedLM,
|
TFFunnelForMaskedLM,
|
||||||
TFFunnelForMultipleChoice,
|
TFFunnelForMultipleChoice,
|
||||||
TFFunnelForPreTraining,
|
TFFunnelForPreTraining,
|
||||||
@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
|
|||||||
(XLMConfig, TFXLMModel),
|
(XLMConfig, TFXLMModel),
|
||||||
(CTRLConfig, TFCTRLModel),
|
(CTRLConfig, TFCTRLModel),
|
||||||
(ElectraConfig, TFElectraModel),
|
(ElectraConfig, TFElectraModel),
|
||||||
(FunnelConfig, TFFunnelModel),
|
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
|
||||||
(DPRConfig, TFDPRQuestionEncoder),
|
(DPRConfig, TFDPRQuestionEncoder),
|
||||||
(MPNetConfig, TFMPNetModel),
|
(MPNetConfig, TFMPNetModel),
|
||||||
(BartConfig, TFBartModel),
|
(BartConfig, TFBartModel),
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -234,7 +235,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
@@ -46,6 +47,8 @@ if is_torch_available():
|
|||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
BertForTokenClassification,
|
BertForTokenClassification,
|
||||||
BertModel,
|
BertModel,
|
||||||
|
FunnelBaseModel,
|
||||||
|
FunnelModel,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPT2LMHeadModel,
|
GPT2LMHeadModel,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
@@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(model.num_parameters(), 14410)
|
self.assertEqual(model.num_parameters(), 14410)
|
||||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||||
|
|
||||||
|
def test_from_pretrained_with_tuple_values(self):
|
||||||
|
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||||
|
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||||
|
self.assertIsInstance(model, FunnelModel)
|
||||||
|
|
||||||
|
config = copy.deepcopy(model.config)
|
||||||
|
config.architectures = ["FunnelBaseModel"]
|
||||||
|
model = AutoModel.from_config(config)
|
||||||
|
self.assertIsInstance(model, FunnelBaseModel)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
model = AutoModel.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(model, FunnelBaseModel)
|
||||||
|
|
||||||
def test_parents_and_children_in_mappings(self):
|
def test_parents_and_children_in_mappings(self):
|
||||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||||
# by the parents and will return the wrong configuration type when using auto models
|
# by the parents and will return the wrong configuration type when using auto models
|
||||||
@@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
assert not issubclass(
|
assert not issubclass(
|
||||||
child_config, parent_config
|
child_config, parent_config
|
||||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
||||||
assert not issubclass(
|
|
||||||
child_model, parent_model
|
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
if not isinstance(child_model, (list, tuple)):
|
||||||
|
child_model = (child_model,)
|
||||||
|
if not isinstance(parent_model, (list, tuple)):
|
||||||
|
parent_model = (parent_model,)
|
||||||
|
|
||||||
|
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||||
|
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -444,7 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
|
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
@@ -458,7 +459,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +80,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||||
@@ -88,9 +89,9 @@ class ModelTesterMixin:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||||
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
inputs_dict["start_positions"] = torch.zeros(
|
inputs_dict["start_positions"] = torch.zeros(
|
||||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
@@ -98,18 +99,18 @@ class ModelTesterMixin:
|
|||||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
|
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||||
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
|
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||||
*MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
|
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||||
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||||
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||||
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
@@ -229,7 +230,7 @@ class ModelTesterMixin:
|
|||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in MODEL_MAPPING.values():
|
if model_class in get_values(MODEL_MAPPING):
|
||||||
continue
|
continue
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -248,7 +249,7 @@ class ModelTesterMixin:
|
|||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in MODEL_MAPPING.values():
|
if model_class in get_values(MODEL_MAPPING):
|
||||||
continue
|
continue
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -312,7 +313,7 @@ class ModelTesterMixin:
|
|||||||
if "labels" in inputs_dict:
|
if "labels" in inputs_dict:
|
||||||
correct_outlen += 1 # loss is added to beginning
|
correct_outlen += 1 # loss is added to beginning
|
||||||
# Question Answering model returns start_logits and end_logits
|
# Question Answering model returns start_logits and end_logits
|
||||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
if "past_key_values" in outputs:
|
if "past_key_values" in outputs:
|
||||||
correct_outlen += 1 # past_key_values have been returned
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -352,7 +353,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if "labels" in inputs_dict:
|
if "labels" in inputs_dict:
|
||||||
correct_outlen += 1 # loss is added to beginning
|
correct_outlen += 1 # loss is added to beginning
|
||||||
# Question Answering model returns start_logits and end_logits
|
# Question Answering model returns start_logits and end_logits
|
||||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
if "past_key_values" in outputs:
|
if "past_key_values" in outputs:
|
||||||
correct_outlen += 1 # past_key_values have been returned
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -292,7 +293,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ if is_flax_available():
|
|||||||
FlaxBertForNextSentencePrediction,
|
FlaxBertForNextSentencePrediction,
|
||||||
FlaxBertForPreTraining,
|
FlaxBertForPreTraining,
|
||||||
FlaxBertForQuestionAnswering,
|
FlaxBertForQuestionAnswering,
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
FlaxBertForTokenClassification,
|
FlaxBertForTokenClassification,
|
||||||
FlaxBertModel,
|
FlaxBertModel,
|
||||||
)
|
)
|
||||||
@@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
FlaxBertForMultipleChoice,
|
FlaxBertForMultipleChoice,
|
||||||
FlaxBertForQuestionAnswering,
|
FlaxBertForQuestionAnswering,
|
||||||
FlaxBertForNextSentencePrediction,
|
FlaxBertForNextSentencePrediction,
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
FlaxBertForTokenClassification,
|
FlaxBertForTokenClassification,
|
||||||
FlaxBertForQuestionAnswering,
|
FlaxBertForQuestionAnswering,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import FunnelTokenizer, is_torch_available
|
from transformers import FunnelTokenizer, is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -365,7 +366,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -412,7 +413,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
if "labels" in inputs_dict:
|
if "labels" in inputs_dict:
|
||||||
correct_outlen += 1 # loss is added to beginning
|
correct_outlen += 1 # loss is added to beginning
|
||||||
# Question Answering model returns start_logits and end_logits
|
# Question Answering model returns start_logits and end_logits
|
||||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
if "past_key_values" in outputs:
|
if "past_key_values" in outputs:
|
||||||
correct_outlen += 1 # past_key_values have been returned
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -532,11 +533,11 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
# special case for models like BERT that use multi-loss training for PreTraining
|
# special case for models like BERT that use multi-loss training for PreTraining
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -290,7 +291,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -272,7 +273,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
|
from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -425,7 +426,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||||
@@ -434,9 +435,9 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||||
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
|
elif model_class in get_values(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING):
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
@@ -457,17 +458,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.model_tester.batch_size, dtype=torch.float, device=torch_device
|
self.model_tester.batch_size, dtype=torch.float, device=torch_device
|
||||||
)
|
)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
|
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||||
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
|
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||||
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||||
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||||
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AlbertConfig, is_tf_available
|
from transformers import AlbertConfig, is_tf_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
|
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|||||||
@@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
@@ -39,6 +40,8 @@ if is_tf_available():
|
|||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
|
TFFunnelBaseModel,
|
||||||
|
TFFunnelModel,
|
||||||
TFGPT2LMHeadModel,
|
TFGPT2LMHeadModel,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFT5ForConditionalGeneration,
|
TFT5ForConditionalGeneration,
|
||||||
@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(model.num_parameters(), 14410)
|
self.assertEqual(model.num_parameters(), 14410)
|
||||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||||
|
|
||||||
|
def test_from_pretrained_with_tuple_values(self):
|
||||||
|
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||||
|
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||||
|
self.assertIsInstance(model, TFFunnelModel)
|
||||||
|
|
||||||
|
config = copy.deepcopy(model.config)
|
||||||
|
config.architectures = ["FunnelBaseModel"]
|
||||||
|
model = TFAutoModel.from_config(config)
|
||||||
|
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
model = TFAutoModel.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||||
|
|
||||||
def test_parents_and_children_in_mappings(self):
|
def test_parents_and_children_in_mappings(self):
|
||||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||||
# by the parents and will return the wrong configuration type when using auto models
|
# by the parents and will return the wrong configuration type when using auto models
|
||||||
@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
for parent_config, parent_model in mapping[: index + 1]:
|
for parent_config, parent_model in mapping[: index + 1]:
|
||||||
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
|
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
|
||||||
self.assertFalse(issubclass(child_config, parent_config))
|
self.assertFalse(issubclass(child_config, parent_config))
|
||||||
self.assertFalse(issubclass(child_model, parent_model))
|
|
||||||
|
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||||
|
if not isinstance(child_model, (list, tuple)):
|
||||||
|
child_model = (child_model,)
|
||||||
|
if not isinstance(parent_model, (list, tuple)):
|
||||||
|
parent_model = (parent_model,)
|
||||||
|
|
||||||
|
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||||
|
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import BertConfig, is_tf_available
|
from transformers import BertConfig, is_tf_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -282,7 +283,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
|
||||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
|
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from importlib import import_module
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
_tf_gpu_memory_limit,
|
_tf_gpu_memory_limit,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
@@ -89,7 +90,7 @@ class TFModelTesterMixin:
|
|||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
||||||
if isinstance(v, tf.Tensor) and v.ndim > 0
|
if isinstance(v, tf.Tensor) and v.ndim > 0
|
||||||
@@ -98,21 +99,21 @@ class TFModelTesterMixin:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
|
||||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
*get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||||
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
|
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
|
||||||
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
|
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
|
||||||
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = tf.zeros(
|
inputs_dict["labels"] = tf.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||||
@@ -580,7 +581,7 @@ class TFModelTesterMixin:
|
|||||||
),
|
),
|
||||||
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
||||||
}
|
}
|
||||||
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||||
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||||
else:
|
else:
|
||||||
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
||||||
@@ -796,9 +797,9 @@ class TFModelTesterMixin:
|
|||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
list_lm_models = (
|
list_lm_models = (
|
||||||
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
|
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
|
||||||
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
|
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
|
||||||
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
|
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
|
||||||
)
|
)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
|
|||||||
]
|
]
|
||||||
loss_size = tf.size(added_label)
|
loss_size = tf.size(added_label)
|
||||||
|
|
||||||
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||||
# if loss is causal lm loss, labels are shift, so that one label per batch
|
# if loss is causal lm loss, labels are shift, so that one label per batch
|
||||||
# is cut
|
# is cut
|
||||||
loss_size = loss_size - self.model_tester.batch_size
|
loss_size = loss_size - self.model_tester.batch_size
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import os
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.models.auto import get_values
|
||||||
|
|
||||||
|
|
||||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||||
# python utils/check_repo.py
|
# python utils/check_repo.py
|
||||||
@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"DPRReader",
|
"DPRReader",
|
||||||
"DPRSpanPredictor",
|
"DPRSpanPredictor",
|
||||||
"FlaubertForQuestionAnswering",
|
"FlaubertForQuestionAnswering",
|
||||||
"FunnelBaseModel",
|
|
||||||
"GPT2DoubleHeadsModel",
|
"GPT2DoubleHeadsModel",
|
||||||
"OpenAIGPTDoubleHeadsModel",
|
"OpenAIGPTDoubleHeadsModel",
|
||||||
"RagModel",
|
"RagModel",
|
||||||
@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"T5Stack",
|
"T5Stack",
|
||||||
"TFDPRReader",
|
"TFDPRReader",
|
||||||
"TFDPRSpanPredictor",
|
"TFDPRSpanPredictor",
|
||||||
"TFFunnelBaseModel",
|
|
||||||
"TFGPT2DoubleHeadsModel",
|
"TFGPT2DoubleHeadsModel",
|
||||||
"TFOpenAIGPTDoubleHeadsModel",
|
"TFOpenAIGPTDoubleHeadsModel",
|
||||||
"TFRagModel",
|
"TFRagModel",
|
||||||
@@ -153,7 +153,7 @@ def get_model_modules():
|
|||||||
def get_models(module):
|
def get_models(module):
|
||||||
""" Get the objects in module that are models."""
|
""" Get the objects in module that are models."""
|
||||||
models = []
|
models = []
|
||||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
|
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
if "Pretrained" in attr_name or "PreTrained" in attr_name:
|
if "Pretrained" in attr_name or "PreTrained" in attr_name:
|
||||||
continue
|
continue
|
||||||
@@ -249,10 +249,13 @@ def get_all_auto_configured_models():
|
|||||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||||
|
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||||
|
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||||
return [cls.__name__ for cls in result]
|
return [cls.__name__ for cls in result]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user