Fixes bug that appears when using QA bert and distilation. (#12026)

* Fixing bug that appears when using distilation (and potentially other uses).
During backward pass Pytorch complains with:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
This happens because the QA model code modifies the start_positions and end_positions input tensors, using clamp_ function: as a consequence the teacher and the student both modifies the inputs, and backward pass fails.

* Fixing all models QA clamp_ bug.
This commit is contained in:
François Lagunas
2021-06-07 17:21:59 +02:00
committed by GitHub
parent 59f75d538b
commit f8bd8c6c7e
25 changed files with 52 additions and 52 deletions

View File

@@ -1230,8 +1230,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1578,8 +1578,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1813,8 +1813,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -2995,8 +2995,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -2783,8 +2783,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1305,8 +1305,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1376,8 +1376,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1500,8 +1500,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -740,8 +740,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1330,8 +1330,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1561,8 +1561,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1331,8 +1331,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -2607,8 +2607,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -2029,8 +2029,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1585,8 +1585,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1806,8 +1806,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1383,8 +1383,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1035,8 +1035,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -2567,8 +2567,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1484,8 +1484,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1554,8 +1554,8 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1080,8 +1080,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -953,8 +953,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1874,8 +1874,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)

View File

@@ -1516,8 +1516,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
@@ -3066,8 +3066,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)