[trainer] 2 bug fixes and a rename (#12309)

* bug fixes and a rename

* add extended DDP test
This commit is contained in:
Stas Bekman
2021-06-22 11:13:23 -07:00
committed by GitHub
parent 64029abe4c
commit ebe5413589
7 changed files with 112 additions and 18 deletions

View File

@@ -14,6 +14,7 @@
import math
import os
import re
import sys
import unittest
from unittest.mock import patch
@@ -21,6 +22,7 @@ from unittest.mock import patch
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
TestCasePlus,
execute_subprocess_async,
@@ -68,7 +70,15 @@ def require_apex(test_case):
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
def run_seq2seq_quick(
self,
distributed=False,
extra_args_str=None,
predict_with_generate=True,
do_train=True,
do_eval=True,
do_predict=True,
):
output_dir = self.run_trainer(
eval_steps=1,
max_len=12,
@@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus):
distributed=distributed,
extra_args_str=extra_args_str,
predict_with_generate=predict_with_generate,
do_train=do_train,
do_eval=do_eval,
do_predict=do_predict,
)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
if not do_eval:
return
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
@@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus):
# to reproduce the problem set distributed=False
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
@require_torch_multi_gpu
def test_trainer_log_level_replica(self):
log_info_string = "Running training"
kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
# test with the default log_level - should be info and thus log info once
with CaptureStderr() as cl:
self.run_seq2seq_quick(
**kwargs,
extra_args_str="",
)
n_matches = len(re.findall(log_info_string, cl.err))
self.assertEqual(n_matches, 1)
# test with low log_level and log_level_replica - should be noisy on all processes
# now the info string should appear twice on 2 processes
with CaptureStderr() as cl:
self.run_seq2seq_quick(
**kwargs,
extra_args_str="--log_level debug --log_level_replica debug",
)
n_matches = len(re.findall(log_info_string, cl.err))
self.assertEqual(n_matches, 2)
# test with high log_level and low log_level_replica
# now the info string should appear once only on the replica
with CaptureStderr() as cl:
self.run_seq2seq_quick(
**kwargs,
extra_args_str="--log_level error --log_level_replica debug",
)
n_matches = len(re.findall(log_info_string, cl.err))
self.assertEqual(n_matches, 1)
# test with high log_level and log_level_replica - should be quiet on all processes
with CaptureStderr() as cl:
self.run_seq2seq_quick(
**kwargs,
extra_args_str="--log_level error --log_level_replica error",
)
n_matches = len(re.findall(log_info_string, cl.err))
self.assertEqual(n_matches, 0)
@slow
def test_run_seq2seq_slow(self):
output_dir = self.run_trainer(
@@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus):
distributed: bool = False,
extra_args_str: str = None,
predict_with_generate: bool = True,
do_train: bool = True,
do_eval: bool = True,
do_predict: bool = True,
):
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
args_train = f"""
--model_name_or_path {model_name}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
@@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus):
--output_dir {output_dir}
--overwrite_output_dir
--max_train_samples 8
--max_eval_samples 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--do_train
--do_eval
--do_predict
--num_train_epochs {str(num_train_epochs)}
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--learning_rate {learning_rate}
--warmup_steps 8
--evaluation_strategy steps
--logging_steps 0
--eval_steps {str(eval_steps)}
--save_steps {str(eval_steps)}
--group_by_length
--label_smoothing_factor 0.1
@@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus):
--target_lang ro_RO
--source_lang en_XX
"""
args_eval = f"""
--do_eval
--per_device_eval_batch_size 4
--max_eval_samples 8
--val_max_target_length {max_len}
--evaluation_strategy steps
--eval_steps {str(eval_steps)}
"""
args_predict = """
--do_predict
"""
args = ""
if do_train:
args += args_train
if do_eval:
args += args_eval
if do_predict:
args += args_predict
if predict_with_generate:
args += "--predict_with_generate"