modify qa-trainer (#11872)

* modify qa-trainer

* fix flax model
This commit is contained in:
Fan Zhang
2021-06-01 20:28:41 +08:00
committed by GitHub
parent 9ec0f01b6c
commit 7e73601f32
25 changed files with 57 additions and 49 deletions

View File

@@ -692,7 +692,11 @@ def main():
if completed_steps >= args.max_train_steps: if completed_steps >= args.max_train_steps:
break break
# Validation # Evaluation
logger.info("***** Running Evaluation *****")
logger.info(f" Num examples = {len(eval_dataset)}")
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
all_start_logits = [] all_start_logits = []
all_end_logits = [] all_end_logits = []
for step, batch in enumerate(eval_dataloader): for step, batch in enumerate(eval_dataloader):
@@ -725,6 +729,10 @@ def main():
# Prediction # Prediction
if args.do_predict: if args.do_predict:
logger.info("***** Running Prediction *****")
logger.info(f" Num examples = {len(predict_dataset)}")
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
all_start_logits = [] all_start_logits = []
all_end_logits = [] all_end_logits = []
for step, batch in enumerate(predict_dataloader): for step, batch in enumerate(predict_dataloader):

View File

@@ -1218,8 +1218,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1556,8 +1556,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1801,8 +1801,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -2983,8 +2983,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
logits = logits - logits_mask * 1e6 logits = logits - logits_mask * 1e6
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -2761,8 +2761,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1293,8 +1293,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1364,8 +1364,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1488,8 +1488,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -728,8 +728,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim) hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) # (bs, max_query_len) start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len) end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len)
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -241,8 +241,8 @@ class DPRSpanPredictor(PreTrainedModel):
# compute logits # compute logits
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
relevance_logits = self.qa_classifier(sequence_output[:, 0, :]) relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
# resize # resize

View File

@@ -1318,8 +1318,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1549,8 +1549,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
logits = self.qa_outputs(last_hidden_state) logits = self.qa_outputs(last_hidden_state)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1319,8 +1319,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -2585,8 +2585,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -2017,8 +2017,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1563,8 +1563,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1794,8 +1794,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1371,8 +1371,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1023,8 +1023,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -2555,8 +2555,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1472,8 +1472,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1068,8 +1068,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -941,8 +941,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:

View File

@@ -1862,8 +1862,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None: