Fix loss calculation in TFXXXForTokenClassification models (#15294)
* Fix loss calculation in TFFunnelForTokenClassification
* revert the change in TFFunnelForTokenClassification
* fix FunnelForTokenClassification loss
* fix other TokenClassification loss
* fix more
* fix more
* add num_labels to ElectraForTokenClassification
* revert the change to research projects

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1150,16 +1150,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1759,16 +1759,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -2888,16 +2888,7 @@ class BigBirdForTokenClassification(BigBirdPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1510,16 +1510,7 @@ class CanineForTokenClassification(CaninePreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1238,16 +1238,7 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1303,16 +1303,7 @@ class DebertaForTokenClassification(DebertaPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1418,16 +1418,7 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -973,16 +973,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1242,6 +1242,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
|||||||
class ElectraForTokenClassification(ElectraPreTrainedModel):
|
class ElectraForTokenClassification(ElectraPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.electra = ElectraModel(config)
|
self.electra = ElectraModel(config)
|
||||||
classifier_dropout = (
|
classifier_dropout = (
|
||||||
@@ -1296,17 +1297,8 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.config.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + discriminator_hidden_states[1:]
|
output = (logits,) + discriminator_hidden_states[1:]
|
||||||
|
|||||||
@@ -1469,16 +1469,7 @@ class FunnelForTokenClassification(FunnelPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1527,16 +1527,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + transformer_outputs[2:]
|
output = (logits,) + transformer_outputs[2:]
|
||||||
|
|||||||
@@ -1219,16 +1219,7 @@ class IBertForTokenClassification(IBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1213,16 +1213,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1205,14 +1205,7 @@ class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
|
||||||
active_labels = labels.view(-1)[active_loss]
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -2156,16 +2156,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1727,16 +1727,7 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1579,16 +1579,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -927,16 +927,7 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1017,16 +1017,7 @@ class NystromformerForTokenClassification(NystromformerPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1622,16 +1622,7 @@ class QDQBertForTokenClassification(QDQBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1413,16 +1413,7 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1414,16 +1414,7 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1472,16 +1472,7 @@ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -984,16 +984,7 @@ class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -1169,16 +1169,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1680,16 +1680,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -1455,16 +1455,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# Only keep active parts of the loss
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
if attention_mask is not None:
|
|
||||||
active_loss = attention_mask.view(-1) == 1
|
|
||||||
active_logits = logits.view(-1, self.num_labels)
|
|
||||||
active_labels = torch.where(
|
|
||||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
|
||||||
)
|
|
||||||
loss = loss_fct(active_logits, active_labels)
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
Reference in New Issue
Block a user