Clean up diffs in Trainer/TFTrainer (#5417)

* Cleanup and unify Trainer/TFTrainer

* Forgot to adapt TFTrainingArgs

* In tf scripts n_gpu -> n_replicas

* Update src/transformers/training_args.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Formatting

* Fix typo

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-07-01 11:00:20 -04:00
committed by GitHub
parent 43cb03a93d
commit 734a28a767
10 changed files with 109 additions and 56 deletions

View File

@@ -108,7 +108,10 @@ def main():
level=logging.INFO,
)
logger.warning(
"device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
"device: %s, n_replicas: %s, 16-bits training: %s",
training_args.device,
training_args.n_replicas,
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -137,9 +137,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -131,9 +131,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -109,9 +109,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)