Fixes a bug that appears when using QA BERT and distillation. (#12026)
* Fixing a bug that appears when using distillation (and potentially other uses). During the backward pass PyTorch complains with: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation This happens because the QA model code modifies the start_positions and end_positions input tensors in place, using the clamp_ function: as a consequence, the teacher and the student both modify the inputs, and the backward pass fails. * Fixing the clamp_ bug in all QA models.
This commit is contained in:
@@ -1230,8 +1230,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1578,8 +1578,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1813,8 +1813,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -2995,8 +2995,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -2783,8 +2783,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1305,8 +1305,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1376,8 +1376,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1500,8 +1500,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -740,8 +740,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1330,8 +1330,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1561,8 +1561,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1331,8 +1331,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -2607,8 +2607,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -2029,8 +2029,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1585,8 +1585,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1806,8 +1806,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1383,8 +1383,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1035,8 +1035,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -2567,8 +2567,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1484,8 +1484,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1554,8 +1554,8 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1080,8 +1080,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -953,8 +953,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1874,8 +1874,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
@@ -1516,8 +1516,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
@@ -3066,8 +3066,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
end_positions = end_positions.squeeze(-1)
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
end_positions.clamp_(0, ignored_index)
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
|||||||
Reference in New Issue
Block a user