Add utility to find model labels (#16526)
* Add utility to find model labels * Use it in the Trainer * Update src/transformers/utils/generic.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Quality Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -35,10 +35,14 @@ from transformers.utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
find_labels,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
class GenericUtilTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_no_context(self, mock_stdout):
|
||||
def test_context_managers_no_context(self, mock_stdout):
|
||||
with ContextManagers([]):
|
||||
print("Transformers are awesome!")
|
||||
# The print statement adds a new line at the end of the output
|
||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_one_context(self, mock_stdout):
|
||||
def test_context_managers_one_context(self, mock_stdout):
|
||||
with ContextManagers([context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_two_context(self, mock_stdout):
|
||||
def test_context_managers_two_context(self, mock_stdout):
|
||||
with ContextManagers([context_fr(), context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English and French welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||
|
||||
def test_find_labels(self):
|
||||
if is_torch_available():
|
||||
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
|
||||
|
||||
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
|
||||
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
|
||||
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
|
||||
|
||||
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
|
||||
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
|
||||
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||
|
||||
if is_flax_available():
|
||||
# Flax models don't have labels
|
||||
from transformers import (
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
)
|
||||
|
||||
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
|
||||
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
|
||||
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
|
||||
|
||||
Reference in New Issue
Block a user