Add multi-class, multi-label and regression to transformers (#11012)
* add to bert * review comments * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * self.config.problem_type * fix style * fix * fin * fix * update doc * fix * test * Test more problem types * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * remove * fix * quality * make fix-copies * remove test Co-authored-by: abhishek thakur <abhishekkrthakur@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -89,6 +89,7 @@ class ModelTesterMixin:
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
test_sequence_classification_problem_types = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
@@ -1238,6 +1239,42 @@ class ModelTesterMixin:
|
||||
model.parallelize()
|
||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||
|
||||
def test_problem_types(self):
|
||||
if not self.test_sequence_classification_problem_types:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
problem_types = [
|
||||
{"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
|
||||
{"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
|
||||
{"title": "regression", "num_labels": 1, "dtype": torch.float},
|
||||
]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
continue
|
||||
|
||||
for problem_type in problem_types:
|
||||
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||
|
||||
config.problem_type = problem_type["title"]
|
||||
config.num_labels = problem_type["num_labels"]
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
if problem_type["num_labels"] > 1:
|
||||
inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
|
||||
|
||||
inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
|
||||
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user