Generalize problem_type to all sequence classification models (#14180)

* Generalize problem_type to all classification models

* Missing import

* Deberta BC and fix tests

* Fix template

* Missing imports

* Revert change to reformer test

* Fix style
This commit is contained in:
Sylvain Gugger
2021-10-29 10:32:56 -04:00
committed by GitHub
parent 4ab6a4a086
commit c28bc80bbb
38 changed files with 474 additions and 191 deletions

View File

@@ -350,7 +350,6 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(XLMWithLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_sequence_classification_problem_types = True
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):