Fix arguments passed to predict function in QA Seq2seq training script (#21026)
fix args passed to predict function
This commit is contained in:
@@ -151,7 +151,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
if self.post_process_function is None or self.compute_metrics is None:
|
if self.post_process_function is None or self.compute_metrics is None:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict")
|
||||||
metrics = self.compute_metrics(predictions)
|
metrics = self.compute_metrics(predictions)
|
||||||
|
|
||||||
# Prefix all keys with metric_key_prefix + '_'
|
# Prefix all keys with metric_key_prefix + '_'
|
||||||
|
|||||||
Reference in New Issue
Block a user