From 3951b9f3908bfa30be7fd814cd2ad1039d3162d8 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 4 Apr 2022 10:06:57 -0400 Subject: [PATCH] Add utility to find model labels (#16526) * Add utility to find model labels * Use it in the Trainer * Update src/transformers/utils/generic.py Co-authored-by: Matt * Quality Co-authored-by: Matt --- src/transformers/trainer.py | 8 ++---- src/transformers/utils/__init__.py | 1 + src/transformers/utils/generic.py | 21 ++++++++++++++++ tests/utils/test_file_utils.py | 39 +++++++++++++++++++++++++++--- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 948697e351..921b9d27ac 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab from .dependency_versions_check import dep_version_check from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, unwrap_model -from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES from .optimization import Adafactor, get_scheduler from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( @@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .utils import ( CONFIG_NAME, WEIGHTS_NAME, + find_labels, get_full_repo_name, is_apex_available, is_datasets_available, @@ -495,11 +495,7 @@ class Trainer: self.current_flos = 0 self.hp_search_backend = None self.use_tune_checkpoints = False - default_label_names = ( - ["start_positions", "end_positions"] - if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values() - else ["labels"] - ) + default_label_names = find_labels(self.model.__class__) self.label_names = default_label_names if self.args.label_names is None else self.args.label_names self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index af326b53e8..45364fb8fd 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -37,6 +37,7 @@ from .generic import ( PaddingStrategy, TensorType, cached_property, + find_labels, is_tensor, to_numpy, to_py_obj, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index e455cdc6ad..bea5b3dd47 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -15,6 +15,7 @@ Generic utilities """ +import inspect from collections import OrderedDict, UserDict from contextlib import ExitStack from dataclasses import fields @@ -289,3 +290,23 @@ class ContextManagers: def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs) + + +def find_labels(model_class): + """ + Find the labels used by a given model. + + Args: + model_class (`type`): The class of the model. + """ + model_name = model_class.__name__ + if model_name.startswith("TF"): + signature = inspect.signature(model_class.call) + elif model_name.startswith("Flax"): + signature = inspect.signature(model_class.__call__) + else: + signature = inspect.signature(model_class.forward) + if "QuestionAnswering" in model_name: + return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")] + else: + return [p for p in signature.parameters if "label" in p] diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index decc7fd17c..75c4f19caa 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -35,10 +35,14 @@ from transformers.utils import ( RepositoryNotFoundError, RevisionNotFoundError, filename_to_url, + find_labels, get_file_from_repo, get_from_cache, has_file, hf_bucket_url, + is_flax_available, + is_tf_available, + is_torch_available, ) @@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase): self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) -class ContextManagerTests(unittest.TestCase): +class GenericUtilTests(unittest.TestCase): @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) - def test_no_context(self, mock_stdout): + def test_context_managers_no_context(self, mock_stdout): with ContextManagers([]): print("Transformers are awesome!") # The print statement adds a new line at the end of the output self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n") @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) - def test_one_context(self, mock_stdout): + def test_context_managers_one_context(self, mock_stdout): with ContextManagers([context_en()]): print("Transformers are awesome!") # The output should be wrapped with an English welcome and goodbye self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n") @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) - def test_two_context(self, mock_stdout): + def test_context_managers_two_context(self, mock_stdout): with ContextManagers([context_fr(), context_en()]): print("Transformers are awesome!") # The output should be wrapped with an English and French welcome and goodbye self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n") + + def test_find_labels(self): + if is_torch_available(): + from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification + + self.assertEqual(find_labels(BertForSequenceClassification), ["labels"]) + self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"]) + self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"]) + + if is_tf_available(): + from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification + + self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"]) + self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"]) + self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"]) + + if is_flax_available(): + # Flax models don't have labels + from transformers import ( + FlaxBertForPreTraining, + FlaxBertForQuestionAnswering, + FlaxBertForSequenceClassification, + ) + + self.assertEqual(find_labels(FlaxBertForSequenceClassification), []) + self.assertEqual(find_labels(FlaxBertForPreTraining), []) + self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])