Add utility to find model labels (#16526)

* Add utility to find model labels

* Use it in the Trainer

* Update src/transformers/utils/generic.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Quality

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2022-04-04 10:06:57 -04:00
committed by GitHub
parent ec4da72fe9
commit 3951b9f390
4 changed files with 59 additions and 10 deletions

View File

@@ -35,10 +35,14 @@ from transformers.utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels,
get_file_from_repo,
get_from_cache,
has_file,
hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
)
@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
class ContextManagerTests(unittest.TestCase):
class GenericUtilTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout):
def test_context_managers_no_context(self, mock_stdout):
with ContextManagers([]):
print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout):
def test_context_managers_one_context(self, mock_stdout):
with ContextManagers([context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout):
def test_context_managers_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
def test_find_labels(self):
if is_torch_available():
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
if is_tf_available():
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
if is_flax_available():
# Flax models don't have labels
from transformers import (
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
)
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])