Add utility to find model labels (#16526)
* Add utility to find model labels * Use it in the Trainer * Update src/transformers/utils/generic.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Quality Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||
from .optimization import Adafactor, get_scheduler
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_callback import (
|
||||
@@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
find_labels,
|
||||
get_full_repo_name,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
@@ -495,11 +495,7 @@ class Trainer:
|
||||
self.current_flos = 0
|
||||
self.hp_search_backend = None
|
||||
self.use_tune_checkpoints = False
|
||||
default_label_names = (
|
||||
["start_positions", "end_positions"]
|
||||
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
|
||||
else ["labels"]
|
||||
)
|
||||
default_label_names = find_labels(self.model.__class__)
|
||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from .generic import (
|
||||
PaddingStrategy,
|
||||
TensorType,
|
||||
cached_property,
|
||||
find_labels,
|
||||
is_tensor,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
Generic utilities
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from collections import OrderedDict, UserDict
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import fields
|
||||
@@ -289,3 +290,23 @@ class ContextManagers:
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.stack.__exit__(*args, **kwargs)
|
||||
|
||||
|
||||
def find_labels(model_class):
|
||||
"""
|
||||
Find the labels used by a given model.
|
||||
|
||||
Args:
|
||||
model_class (`type`): The class of the model.
|
||||
"""
|
||||
model_name = model_class.__name__
|
||||
if model_name.startswith("TF"):
|
||||
signature = inspect.signature(model_class.call)
|
||||
elif model_name.startswith("Flax"):
|
||||
signature = inspect.signature(model_class.__call__)
|
||||
else:
|
||||
signature = inspect.signature(model_class.forward)
|
||||
if "QuestionAnswering" in model_name:
|
||||
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
|
||||
else:
|
||||
return [p for p in signature.parameters if "label" in p]
|
||||
|
||||
@@ -35,10 +35,14 @@ from transformers.utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
find_labels,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
class GenericUtilTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_no_context(self, mock_stdout):
|
||||
def test_context_managers_no_context(self, mock_stdout):
|
||||
with ContextManagers([]):
|
||||
print("Transformers are awesome!")
|
||||
# The print statement adds a new line at the end of the output
|
||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_one_context(self, mock_stdout):
|
||||
def test_context_managers_one_context(self, mock_stdout):
|
||||
with ContextManagers([context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_two_context(self, mock_stdout):
|
||||
def test_context_managers_two_context(self, mock_stdout):
|
||||
with ContextManagers([context_fr(), context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English and French welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||
|
||||
def test_find_labels(self):
|
||||
if is_torch_available():
|
||||
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
|
||||
|
||||
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
|
||||
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
|
||||
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
|
||||
|
||||
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
|
||||
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
|
||||
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||
|
||||
if is_flax_available():
|
||||
# Flax models don't have labels
|
||||
from transformers import (
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
)
|
||||
|
||||
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
|
||||
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
|
||||
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
|
||||
|
||||
Reference in New Issue
Block a user