Add predict step accumulation (#7767)
* Add eval_accumulation_step and clean distributed eval * Add TPU test * Add TPU stuff * Fix arg name * Fix Seq2SeqTrainer * Fix total_size * Update src/transformers/trainer_pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Doc and add test to TPU * Add unit test * Adapt name Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -13,15 +13,14 @@
|
||||
# CUDA_VISIBLE_DEVICES=-1 python ./tests/test_trainer_distributed.py
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -101,4 +100,20 @@ if __name__ == "__main__":
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
trainer.args.eval_accumulation_steps = 2
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
logger.info(metrics)
|
||||
if metrics["eval_success"] is not True:
|
||||
logger.error(metrics)
|
||||
exit(1)
|
||||
|
||||
p = trainer.predict(dataset)
|
||||
logger.info(p.metrics)
|
||||
if p.metrics["eval_success"] is not True:
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
trainer.args.eval_accumulation_steps = None
|
||||
|
||||
logger.info("🔥 All distributed tests successful")
|
||||
|
||||
Reference in New Issue
Block a user