Add multi-class, multi-label and regression to transformers (#11012)
* add to bert * review comments * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * self.config.problem_type * fix style * fix * fin * fix * update doc * fix * test * Test more problem types * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * remove * fix * quality * make fix-copies * remove test Co-authored-by: abhishek thakur <abhishekkrthakur@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -163,6 +163,14 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
typically for a classification task.
|
typically for a classification task.
|
||||||
- **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
|
- **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
|
||||||
current task.
|
current task.
|
||||||
|
- **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can
|
||||||
|
be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`).
|
||||||
|
Please note that this parameter is only available in the following models: `AlbertForSequenceClassification`,
|
||||||
|
`BertForSequenceClassification`, `BigBirdForSequenceClassification`, `ConvBertForSequenceClassification`,
|
||||||
|
`DistilBertForSequenceClassification`, `ElectraForSequenceClassification`, `FunnelForSequenceClassification`,
|
||||||
|
`LongformerForSequenceClassification`, `MobileBertForSequenceClassification`,
|
||||||
|
`ReformerForSequenceClassification`, `RobertaForSequenceClassification`,
|
||||||
|
`SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and `XLNetForSequenceClassification`.
|
||||||
|
|
||||||
Parameters linked to the tokenizer
|
Parameters linked to the tokenizer
|
||||||
|
|
||||||
@@ -260,6 +268,15 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
# task specific arguments
|
# task specific arguments
|
||||||
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
||||||
|
|
||||||
|
# regression / multi-label classification
|
||||||
|
self.problem_type = kwargs.pop("problem_type", None)
|
||||||
|
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
|
||||||
|
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
|
||||||
|
raise ValueError(
|
||||||
|
f"The config parameter `problem_type` wasnot understood: received {self.problem_type}"
|
||||||
|
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
|
||||||
|
)
|
||||||
|
|
||||||
# TPU arguments
|
# TPU arguments
|
||||||
if kwargs.pop("xla_device", None) is not None:
|
if kwargs.pop("xla_device", None) is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -970,6 +970,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.albert = AlbertModel(config)
|
self.albert = AlbertModel(config)
|
||||||
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
||||||
@@ -1024,13 +1025,23 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -1381,7 +1381,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
@@ -1463,6 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
@@ -1517,14 +1518,23 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -2609,6 +2609,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
self.bert = BigBirdModel(config)
|
self.bert = BigBirdModel(config)
|
||||||
self.classifier = BigBirdClassificationHead(config)
|
self.classifier = BigBirdClassificationHead(config)
|
||||||
|
|
||||||
@@ -2659,13 +2660,23 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from operator import attrgetter
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN, get_activation
|
from ...activations import ACT2FN, get_activation
|
||||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -962,6 +962,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
self.convbert = ConvBertModel(config)
|
self.convbert = ConvBertModel(config)
|
||||||
self.classifier = ConvBertClassificationHead(config)
|
self.classifier = ConvBertClassificationHead(config)
|
||||||
|
|
||||||
@@ -1012,13 +1013,23 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import math
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import gelu
|
from ...activations import gelu
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -579,6 +579,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.distilbert = DistilBertModel(config)
|
self.distilbert = DistilBertModel(config)
|
||||||
self.pre_classifier = nn.Linear(config.dim, config.dim)
|
self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||||
@@ -631,12 +632,23 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
loss_fct = nn.MSELoss()
|
if self.num_labels == 1:
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
self.config.problem_type = "regression"
|
||||||
else:
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + distilbert_output[1:]
|
output = (logits,) + distilbert_output[1:]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN, get_activation
|
from ...activations import ACT2FN, get_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -903,6 +903,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
self.electra = ElectraModel(config)
|
self.electra = ElectraModel(config)
|
||||||
self.classifier = ElectraClassificationHead(config)
|
self.classifier = ElectraClassificationHead(config)
|
||||||
|
|
||||||
@@ -953,13 +954,23 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + discriminator_hidden_states[1:]
|
output = (logits,) + discriminator_hidden_states[1:]
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -1240,6 +1240,7 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.funnel = FunnelBaseModel(config)
|
self.funnel = FunnelBaseModel(config)
|
||||||
self.classifier = FunnelClassificationHead(config, config.num_labels)
|
self.classifier = FunnelClassificationHead(config, config.num_labels)
|
||||||
@@ -1287,13 +1288,23 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...activations import ACT2FN, gelu
|
from ...activations import ACT2FN, gelu
|
||||||
@@ -1803,6 +1803,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.longformer = LongformerModel(config, add_pooling_layer=False)
|
self.longformer = LongformerModel(config, add_pooling_layer=False)
|
||||||
self.classifier = LongformerClassificationHead(config)
|
self.classifier = LongformerClassificationHead(config)
|
||||||
@@ -1861,13 +1862,23 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -1214,6 +1214,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.mobilebert = MobileBertModel(config)
|
self.mobilebert = MobileBertModel(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
@@ -1268,14 +1269,23 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.autograd.function import Function
|
from torch.autograd.function import Function
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -366,7 +366,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
past_buckets_states=None,
|
past_buckets_states=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
sequence_length = hidden_states.shape[1]
|
sequence_length = hidden_states.shape[1]
|
||||||
batch_size = hidden_states.shape[0]
|
batch_size = hidden_states.shape[0]
|
||||||
@@ -1045,7 +1045,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
past_buckets_states=None,
|
past_buckets_states=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
sequence_length = hidden_states.shape[1]
|
sequence_length = hidden_states.shape[1]
|
||||||
batch_size = hidden_states.shape[0]
|
batch_size = hidden_states.shape[0]
|
||||||
@@ -2381,6 +2381,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.reformer = ReformerModel(config)
|
self.reformer = ReformerModel(config)
|
||||||
self.classifier = ReformerClassificationHead(config)
|
self.classifier = ReformerClassificationHead(config)
|
||||||
@@ -2434,13 +2435,23 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN, gelu
|
from ...activations import ACT2FN, gelu
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
@@ -1117,6 +1117,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
||||||
self.classifier = RobertaClassificationHead(config)
|
self.classifier = RobertaClassificationHead(config)
|
||||||
@@ -1167,13 +1168,23 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -733,6 +733,7 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.transformer = SqueezeBertModel(config)
|
self.transformer = SqueezeBertModel(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
@@ -787,13 +788,23 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from typing import Optional, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...activations import gelu
|
from ...activations import gelu
|
||||||
@@ -779,6 +779,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.transformer = XLMModel(config)
|
self.transformer = XLMModel(config)
|
||||||
self.sequence_summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
@@ -836,13 +837,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + transformer_outputs[1:]
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -1488,6 +1488,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.transformer = XLNetModel(config)
|
self.transformer = XLNetModel(config)
|
||||||
self.sequence_summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
@@ -1551,13 +1552,23 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.config.problem_type is None:
|
||||||
# We are doing regression
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||||
else:
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + transformer_outputs[1:]
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
|||||||
@@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|||||||
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# head masking & pruning is currently not supported for big bird
|
# head masking & pruning is currently not supported for big bird
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# torchscript should be possible, but takes prohibitively long to test.
|
# torchscript should be possible, but takes prohibitively long to test.
|
||||||
# Also torchscript is not an important feature to have in the beginning.
|
# Also torchscript is not an important feature to have in the beginning.
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ class ModelTesterMixin:
|
|||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
test_sequence_classification_problem_types = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -1238,6 +1239,42 @@ class ModelTesterMixin:
|
|||||||
model.parallelize()
|
model.parallelize()
|
||||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||||
|
|
||||||
|
def test_problem_types(self):
|
||||||
|
if not self.test_sequence_classification_problem_types:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
problem_types = [
|
||||||
|
{"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
|
||||||
|
{"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
|
||||||
|
{"title": "regression", "num_labels": 1, "dtype": torch.float},
|
||||||
|
]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for problem_type in problem_types:
|
||||||
|
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||||
|
|
||||||
|
config.problem_type = problem_type["title"]
|
||||||
|
config.num_labels = problem_type["num_labels"]
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
if problem_type["num_labels"] > 1:
|
||||||
|
inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
|
||||||
|
|
||||||
|
inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
|
||||||
|
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ConvBertModelTester(self)
|
self.model_tester = ConvBertModelTester(self)
|
||||||
|
|||||||
@@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = DistilBertModelTester(self)
|
self.model_tester = DistilBertModelTester(self)
|
||||||
|
|||||||
@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -274,6 +274,7 @@ class LongformerModelTester:
|
|||||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
test_pruning = False # pruning is not supported
|
test_pruning = False # pruning is not supported
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def prepare_kwargs(self):
|
def prepare_kwargs(self):
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = RobertaModelTester(self)
|
self.model_tester = RobertaModelTester(self)
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = SqueezeBertModelTester(self)
|
self.model_tester = SqueezeBertModelTester(self)
|
||||||
|
|||||||
@@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (
|
all_generative_model_classes = (
|
||||||
(XLMWithLMHeadModel,) if is_torch_available() else ()
|
(XLMWithLMHeadModel,) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
|
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user