[trainer] 2 bug fixes and a rename (#12309)
* bug fixes and a rename
* add extended DDP test
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
@@ -21,6 +22,7 @@ from unittest.mock import patch
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.integrations import is_fairscale_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
ExtendSysPath,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
@@ -68,7 +70,15 @@ def require_apex(test_case):
|
||||
|
||||
|
||||
class TestTrainerExt(TestCasePlus):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
|
||||
def run_seq2seq_quick(
|
||||
self,
|
||||
distributed=False,
|
||||
extra_args_str=None,
|
||||
predict_with_generate=True,
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
do_predict=True,
|
||||
):
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
@@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus):
|
||||
distributed=distributed,
|
||||
extra_args_str=extra_args_str,
|
||||
predict_with_generate=predict_with_generate,
|
||||
do_train=do_train,
|
||||
do_eval=do_eval,
|
||||
do_predict=do_predict,
|
||||
)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
|
||||
if not do_eval:
|
||||
return
|
||||
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
|
||||
first_step_stats = eval_metrics[0]
|
||||
@@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus):
|
||||
# to reproduce the problem set distributed=False
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@require_torch_multi_gpu
def test_trainer_log_level_replica(self):
    """Count how often the "Running training" message reaches stderr in distributed runs.

    Every scenario launches a distributed quick seq2seq run (evaluation, prediction and
    generation disabled) with a different ``--log_level`` / ``--log_level_replica``
    combination, captures stderr and asserts the number of occurrences of the message.
    """
    log_info_string = "Running training"
    common = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)

    # (extra command line args, expected number of matches) - executed in this order
    scenarios = (
        # default log_level - should be info and thus log info once
        ("", 1),
        # low log_level and log_level_replica - should be noisy on all processes,
        # so the info string should appear twice on 2 processes
        ("--log_level debug --log_level_replica debug", 2),
        # high log_level and low log_level_replica - the info string should appear
        # once, only on the replica
        ("--log_level error --log_level_replica debug", 1),
        # high log_level and log_level_replica - should be quiet on all processes
        ("--log_level error --log_level_replica error", 0),
    )

    for cli_args, expected in scenarios:
        with CaptureStderr() as captured:
            self.run_seq2seq_quick(extra_args_str=cli_args, **common)
        found = re.findall(log_info_string, captured.err)
        self.assertEqual(len(found), expected)
|
||||
|
||||
@slow
|
||||
def test_run_seq2seq_slow(self):
|
||||
output_dir = self.run_trainer(
|
||||
@@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus):
|
||||
distributed: bool = False,
|
||||
extra_args_str: str = None,
|
||||
predict_with_generate: bool = True,
|
||||
do_train: bool = True,
|
||||
do_eval: bool = True,
|
||||
do_predict: bool = True,
|
||||
):
|
||||
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
args_train = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--train_file {data_dir}/train.json
|
||||
--validation_file {data_dir}/val.json
|
||||
@@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus):
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--max_train_samples 8
|
||||
--max_eval_samples 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_predict
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--learning_rate {learning_rate}
|
||||
--warmup_steps 8
|
||||
--evaluation_strategy steps
|
||||
--logging_steps 0
|
||||
--eval_steps {str(eval_steps)}
|
||||
--save_steps {str(eval_steps)}
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
@@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus):
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
"""
|
||||
|
||||
args_eval = f"""
|
||||
--do_eval
|
||||
--per_device_eval_batch_size 4
|
||||
--max_eval_samples 8
|
||||
--val_max_target_length {max_len}
|
||||
--evaluation_strategy steps
|
||||
--eval_steps {str(eval_steps)}
|
||||
"""
|
||||
|
||||
args_predict = """
|
||||
--do_predict
|
||||
"""
|
||||
|
||||
args = ""
|
||||
if do_train:
|
||||
args += args_train
|
||||
|
||||
if do_eval:
|
||||
args += args_eval
|
||||
|
||||
if do_predict:
|
||||
args += args_predict
|
||||
|
||||
if predict_with_generate:
|
||||
args += "--predict_with_generate"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user