Fix loss calculation in TFXXXForTokenClassification models (#15294)

* Fix loss calculation in TFFunnelForTokenClassification

* revert the change in TFFunnelForTokenClassification

* fix FunnelForTokenClassification loss

* fix other TokenClassification loss

* fix more

* fix more

* add num_labels to ElectraForTokenClassification

* revert the change to research projects

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-01-31 17:43:08 +01:00
committed by GitHub
parent 44c7857b87
commit 554d333ece
27 changed files with 29 additions and 269 deletions

View File

@@ -1455,16 +1455,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]