From 286fdc6b3c51cd09e57b29ba11603f585073aefd Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 1 Mar 2022 18:09:52 +0100 Subject: [PATCH] [vision] Add problem_type support (#15851) * Add problem_type to missing models * Fix deit test Co-authored-by: Niels Rogge --- src/transformers/models/beit/modeling_beit.py | 24 ++++++--- src/transformers/models/deit/modeling_deit.py | 24 ++++++--- .../models/segformer/modeling_segformer.py | 24 ++++++--- src/transformers/models/vit/modeling_vit.py | 24 ++++++--- tests/deit/test_modeling_deit.py | 54 +++++++++++++++++++ tests/test_modeling_common.py | 5 +- 6 files changed, 130 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 69c213b8dd..4fdd9191fa 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -22,7 +22,7 @@ from dataclasses import dataclass import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -857,14 +857,26 @@ class BeitForImageClassification(BeitPreTrainedModel): loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 5e9dca45e7..e86f6de230 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -23,7 +23,7 @@ from typing import Optional, Tuple import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -726,14 +726,26 @@ class DeiTForImageClassification(DeiTPreTrainedModel): loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 7fa9ccade8..3197678a36 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -21,7 +21,7 @@ import math import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -590,14 +590,26 @@ class SegformerForImageClassification(SegformerPreTrainedModel): loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 773be1b6cf..2c6d660ab7 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -21,7 +21,7 @@ import math import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -747,14 +747,26 @@ class ViTForImageClassification(ViTPreTrainedModel): loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/tests/deit/test_modeling_deit.py b/tests/deit/test_modeling_deit.py index 92e6d1c14b..fa89bf231d 100644 --- a/tests/deit/test_modeling_deit.py +++ b/tests/deit/test_modeling_deit.py @@ -17,6 +17,7 @@ import inspect import unittest +import warnings from transformers import DeiTConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available @@ -32,6 +33,8 @@ if is_torch_available(): from torch import nn from transformers import ( + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_MAPPING, DeiTForImageClassification, DeiTForImageClassificationWithTeacher, @@ -379,6 +382,57 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + def test_problem_types(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + problem_types = [ + {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float}, + {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long}, + {"title": "regression", "num_labels": 1, "dtype": torch.float}, + ] + + for model_class in self.all_model_classes: + if ( + model_class + not in [ + *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING), + *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), + ] + or model_class.__name__ == "DeiTForImageClassificationWithTeacher" + ): + continue + + for problem_type in problem_types: + with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"): + + config.problem_type = problem_type["title"] + config.num_labels = problem_type["num_labels"] + + model = model_class(config) + model.to(torch_device) + model.train() + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + if problem_type["num_labels"] > 1: + inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"]) + + inputs["labels"] = inputs["labels"].to(problem_type["dtype"]) + + # This tests that we do not trigger the warning form PyTorch "Using a target size that is different + # to the input size. This will likely lead to incorrect results due to broadcasting. Please ensure + # they have the same size." which is a symptom something in wrong for the regression problem. + # See https://github.com/huggingface/transformers/issues/11780 + with warnings.catch_warnings(record=True) as warning_list: + loss = model(**inputs).loss + for w in warning_list: + if "Using a target size that is different to the input size" in str(w.message): + raise ValueError( + f"Something is going wrong in the regression problem: intercepted {w.message}" + ) + + loss.backward() + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 51aeb7056c..0f2e274814 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1873,7 +1873,10 @@ class ModelTesterMixin: ] for model_class in self.all_model_classes: - if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + if model_class not in [ + *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING), + *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), + ]: continue for problem_type in problem_types: