Add utility to find model labels (#16526)
* Add utility to find model labels
* Use it in the Trainer
* Update src/transformers/utils/generic.py

  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Quality

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab
|
|||||||
from .dependency_versions_check import dep_version_check
|
from .dependency_versions_check import dep_version_check
|
||||||
from .modelcard import TrainingSummary
|
from .modelcard import TrainingSummary
|
||||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
|
||||||
from .optimization import Adafactor, get_scheduler
|
from .optimization import Adafactor, get_scheduler
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
@@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
find_labels,
|
||||||
get_full_repo_name,
|
get_full_repo_name,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
@@ -495,11 +495,7 @@ class Trainer:
|
|||||||
self.current_flos = 0
|
self.current_flos = 0
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
self.use_tune_checkpoints = False
|
self.use_tune_checkpoints = False
|
||||||
default_label_names = (
|
default_label_names = find_labels(self.model.__class__)
|
||||||
["start_positions", "end_positions"]
|
|
||||||
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
|
|
||||||
else ["labels"]
|
|
||||||
)
|
|
||||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from .generic import (
|
|||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
TensorType,
|
TensorType,
|
||||||
cached_property,
|
cached_property,
|
||||||
|
find_labels,
|
||||||
is_tensor,
|
is_tensor,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
Generic utilities
|
Generic utilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
@@ -289,3 +290,23 @@ class ContextManagers:
|
|||||||
|
|
||||||
def __exit__(self, *args, **kwargs):
|
def __exit__(self, *args, **kwargs):
|
||||||
self.stack.__exit__(*args, **kwargs)
|
self.stack.__exit__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def find_labels(model_class):
    """
    Find the labels used by a given model.

    The method to inspect and the question-answering special case are both decided from the class names along the
    MRO of `model_class`, not only from its own name, so that a user-defined subclass (for instance
    `class MyClassifier(TFBertForSequenceClassification)`) is handled like its parent.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `List[str]`: The names, in signature order, of the arguments containing `"label"` in the inspected method
        (`call` when a class name starts with `"TF"`, `__call__` when it starts with `"Flax"`, `forward` otherwise).
        When a class name contains `"QuestionAnswering"`, `start_positions` and `end_positions` are returned as well.

    Raises:
        `AttributeError`: If `model_class` does not define the method selected for inspection.
    """
    lineage = [klass.__name__ for klass in inspect.getmro(model_class)]

    # The first class in the MRO whose name carries a framework prefix decides which method holds the signature.
    # `model_class` itself comes first, so classes that follow the naming convention behave exactly as before.
    method_name = "forward"
    for class_name in lineage:
        if class_name.startswith("TF"):
            method_name = "call"
            break
        if class_name.startswith("Flax"):
            method_name = "__call__"
            break
    signature = inspect.signature(getattr(model_class, method_name))

    # Question answering heads are supervised with span positions instead of (or on top of) `*label*` arguments.
    if any("QuestionAnswering" in class_name for class_name in lineage):
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    else:
        return [p for p in signature.parameters if "label" in p]
|
||||||
|
|||||||
@@ -35,10 +35,14 @@ from transformers.utils import (
|
|||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
filename_to_url,
|
filename_to_url,
|
||||||
|
find_labels,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_from_cache,
|
get_from_cache,
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
|
|||||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||||
|
|
||||||
|
|
||||||
class ContextManagerTests(unittest.TestCase):
|
class GenericUtilTests(unittest.TestCase):
|
||||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
def test_no_context(self, mock_stdout):
|
def test_context_managers_no_context(self, mock_stdout):
|
||||||
with ContextManagers([]):
|
with ContextManagers([]):
|
||||||
print("Transformers are awesome!")
|
print("Transformers are awesome!")
|
||||||
# The print statement adds a new line at the end of the output
|
# The print statement adds a new line at the end of the output
|
||||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||||
|
|
||||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
def test_one_context(self, mock_stdout):
|
def test_context_managers_one_context(self, mock_stdout):
|
||||||
with ContextManagers([context_en()]):
|
with ContextManagers([context_en()]):
|
||||||
print("Transformers are awesome!")
|
print("Transformers are awesome!")
|
||||||
# The output should be wrapped with an English welcome and goodbye
|
# The output should be wrapped with an English welcome and goodbye
|
||||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||||
|
|
||||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
def test_two_context(self, mock_stdout):
|
def test_context_managers_two_context(self, mock_stdout):
|
||||||
with ContextManagers([context_fr(), context_en()]):
|
with ContextManagers([context_fr(), context_en()]):
|
||||||
print("Transformers are awesome!")
|
print("Transformers are awesome!")
|
||||||
# The output should be wrapped with an English and French welcome and goodbye
|
# The output should be wrapped with an English and French welcome and goodbye
|
||||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||||
|
|
||||||
|
def test_find_labels(self):
    """Check `find_labels` against the BERT heads of each framework that is installed."""
    if is_torch_available():
        from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification

        pt_expectations = (
            (BertForSequenceClassification, ["labels"]),
            (BertForPreTraining, ["labels", "next_sentence_label"]),
            (BertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, label_names in pt_expectations:
            self.assertEqual(find_labels(model_class), label_names)

    if is_tf_available():
        from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification

        tf_expectations = (
            (TFBertForSequenceClassification, ["labels"]),
            (TFBertForPreTraining, ["labels", "next_sentence_label"]),
            (TFBertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, label_names in tf_expectations:
            self.assertEqual(find_labels(model_class), label_names)

    if is_flax_available():
        # Flax models don't have labels
        from transformers import (
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
            FlaxBertForSequenceClassification,
        )

        flax_model_classes = (
            FlaxBertForSequenceClassification,
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
        )
        for model_class in flax_model_classes:
            self.assertEqual(find_labels(model_class), [])
|
||||||
|
|||||||
Reference in New Issue
Block a user