[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -3,36 +3,54 @@ import sys
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
MODEL_NAME = MBART_TINY
|
||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||
set_seed(42)
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
@slow
|
||||
def test_model_download():
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer():
|
||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow():
|
||||
# TODO(SS): This will fail on devices with more than 1 GPU.
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
|
||||
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||
max_len = "128"
|
||||
num_train_epochs = 4
|
||||
eval_steps = 2
|
||||
output_dir = tempfile.mkdtemp(prefix="test_output")
|
||||
argv = [
|
||||
"--model_name_or_path",
|
||||
MARIAN_MODEL,
|
||||
model_name,
|
||||
"--data_dir",
|
||||
data_dir,
|
||||
"--output_dir",
|
||||
@@ -72,25 +90,17 @@ def test_finetune_trainer():
|
||||
"--sortish_sampler",
|
||||
"--label_smoothing",
|
||||
"0.1",
|
||||
# "--eval_beams",
|
||||
# "2",
|
||||
"--task",
|
||||
"translation",
|
||||
"--tgt_lang",
|
||||
"ro_RO",
|
||||
"--src_lang",
|
||||
"en_XX",
|
||||
]
|
||||
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
return output_dir
|
||||
|
||||
Reference in New Issue
Block a user