Add multi-class, multi-label and regression to transformers (#11012)
* add to bert * review comments * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * self.config.problem_type * fix style * fix * fin * fix * update doc * fix * test * Test more problem types * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * remove * fix * quality * make fix-copies * remove test Co-authored-by: abhishek thakur <abhishekkrthakur@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# head masking & pruning is currently not supported for big bird
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# torchscript should be possible, but takes prohibitively long to test.
|
||||
# Also torchscript is not an important feature to have in the beginning.
|
||||
|
||||
@@ -89,6 +89,7 @@ class ModelTesterMixin:
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
test_sequence_classification_problem_types = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
@@ -1238,6 +1239,42 @@ class ModelTesterMixin:
|
||||
model.parallelize()
|
||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||
|
||||
def test_problem_types(self):
|
||||
if not self.test_sequence_classification_problem_types:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
problem_types = [
|
||||
{"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
|
||||
{"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
|
||||
{"title": "regression", "num_labels": 1, "dtype": torch.float},
|
||||
]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
continue
|
||||
|
||||
for problem_type in problem_types:
|
||||
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||
|
||||
config.problem_type = problem_type["title"]
|
||||
config.num_labels = problem_type["num_labels"]
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
if problem_type["num_labels"] > 1:
|
||||
inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
|
||||
|
||||
inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
|
||||
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ConvBertModelTester(self)
|
||||
|
||||
@@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DistilBertModelTester(self)
|
||||
|
||||
@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -274,6 +274,7 @@ class LongformerModelTester:
|
||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False # pruning is not supported
|
||||
test_torchscript = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def prepare_kwargs(self):
|
||||
return {
|
||||
|
||||
@@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTester(self)
|
||||
|
||||
@@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SqueezeBertModelTester(self)
|
||||
|
||||
@@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(XLMWithLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_pruning = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
Reference in New Issue
Block a user