Fix TFTrainer prediction output (#9662)
* Fix TFTrainer prediction output * Update trainer_tf.py * Fix TFTrainer prediction output * Fix evaluation_loss update in TFTrainer * Fix TFTrainer prediction output
This commit is contained in:
committed by
GitHub
parent
9152f16023
commit
6312fed47d
@@ -101,6 +101,7 @@ class TFTrainer:
|
|||||||
self.gradient_accumulator = GradientAccumulator()
|
self.gradient_accumulator = GradientAccumulator()
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
self.epoch_logging = 0
|
self.epoch_logging = 0
|
||||||
|
self.eval_loss = tf.keras.metrics.Sum()
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
@@ -202,13 +203,8 @@ class TFTrainer:
|
|||||||
if num_examples < 0:
|
if num_examples < 0:
|
||||||
raise ValueError("The training dataset must have an asserted cardinality")
|
raise ValueError("The training dataset must have an asserted cardinality")
|
||||||
|
|
||||||
approx = math.floor if self.args.dataloader_drop_last else math.ceil
|
steps = math.ceil(num_examples / self.args.eval_batch_size)
|
||||||
steps = approx(num_examples / self.args.eval_batch_size)
|
ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
ds = (
|
|
||||||
test_dataset.repeat()
|
|
||||||
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
|
||||||
.prefetch(tf.data.experimental.AUTOTUNE)
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
|
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
|
||||||
|
|
||||||
@@ -300,12 +296,14 @@ class TFTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info("***** Running %s *****", description)
|
logger.info("***** Running %s *****", description)
|
||||||
logger.info(" Num examples = %d", num_examples)
|
logger.info(" Num examples in dataset = %d", num_examples)
|
||||||
|
if description == "Evaluation":
|
||||||
|
logger.info(" Num examples in used in evaluation = %d", self.args.eval_batch_size * steps)
|
||||||
logger.info(" Batch size = %d", self.args.eval_batch_size)
|
logger.info(" Batch size = %d", self.args.eval_batch_size)
|
||||||
|
|
||||||
label_ids: np.ndarray = None
|
label_ids: np.ndarray = None
|
||||||
preds: np.ndarray = None
|
preds: np.ndarray = None
|
||||||
self.eval_loss = tf.keras.metrics.Sum()
|
self.eval_loss.reset_states()
|
||||||
|
|
||||||
# Reset the past mems state at the beginning of the evaluation if necessary.
|
# Reset the past mems state at the beginning of the evaluation if necessary.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
@@ -345,7 +343,7 @@ class TFTrainer:
|
|||||||
else:
|
else:
|
||||||
label_ids = np.append(label_ids, labels.numpy(), axis=0)
|
label_ids = np.append(label_ids, labels.numpy(), axis=0)
|
||||||
|
|
||||||
if step == steps:
|
if step == steps - 1:
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user