From 9f7b2b243230a0ff7e61f48f852e3a5b6f6d86fa Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 17 Oct 2020 11:33:21 -0700 Subject: [PATCH] [s2s testing] turn all to unittests, use auto-delete temp dirs (#7859) --- examples/seq2seq/test_bash_script.py | 290 +++++------ .../seq2seq/test_data/wmt_en_ro/train.len | Bin 26 -> 34 bytes examples/seq2seq/test_data/wmt_en_ro/val.len | Bin 40 -> 48 bytes examples/seq2seq/test_datasets.py | 386 +++++++------- examples/seq2seq/test_finetune_trainer.py | 126 +++-- examples/seq2seq/test_seq2seq_examples.py | 488 +++++++++--------- 6 files changed, 635 insertions(+), 655 deletions(-) diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py index 0c33771d44..24ce9bfe6b 100644 --- a/examples/seq2seq/test_bash_script.py +++ b/examples/seq2seq/test_bash_script.py @@ -3,7 +3,6 @@ import argparse import os import sys -import tempfile from pathlib import Path from unittest.mock import patch @@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main from finetune import SummarizationModule, main from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY from transformers import BartForConditionalGeneration, MarianMTModel -from transformers.testing_utils import slow +from transformers.testing_utils import TestCasePlus, slow from utils import load_json @@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" -@slow -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") -def test_model_download(): - """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" - BartForConditionalGeneration.from_pretrained(MODEL_NAME) - MarianMTModel.from_pretrained(MARIAN_MODEL) +class TestAll(TestCasePlus): + @slow + @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + def test_model_download(self): + """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" + BartForConditionalGeneration.from_pretrained(MODEL_NAME) + MarianMTModel.from_pretrained(MARIAN_MODEL) + @timeout_decorator.timeout(120) + @slow + @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + def test_train_mbart_cc25_enro_script(self): + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + env_vars_to_replace = { + "--fp16_opt_level=O1": "", + "$MAX_LEN": 128, + "$BS": 4, + "$GAS": 1, + "$ENRO_DIR": data_dir, + "facebook/mbart-large-cc25": MODEL_NAME, + # Download is 120MB in previous test. + "val_check_interval=0.25": "val_check_interval=1.0", + } -@timeout_decorator.timeout(120) -@slow -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") -def test_train_mbart_cc25_enro_script(): - data_dir = "examples/seq2seq/test_data/wmt_en_ro" - env_vars_to_replace = { - "--fp16_opt_level=O1": "", - "$MAX_LEN": 128, - "$BS": 4, - "$GAS": 1, - "$ENRO_DIR": data_dir, - "facebook/mbart-large-cc25": MODEL_NAME, - # Download is 120MB in previous test. - "val_check_interval=0.25": "val_check_interval=1.0", - } + # Clean up bash script + bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() + bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") + for k, v in env_vars_to_replace.items(): + bash_script = bash_script.replace(k, str(v)) + output_dir = self.get_auto_remove_tmp_dir() - # Clean up bash script - bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() - bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") - for k, v in env_vars_to_replace.items(): - bash_script = bash_script.replace(k, str(v)) - output_dir = tempfile.mkdtemp(prefix="output_mbart") + bash_script = bash_script.replace("--fp16 ", "") + testargs = ( + ["finetune.py"] + + bash_script.split() + + [ + f"--output_dir={output_dir}", + "--gpus=1", + "--learning_rate=3e-1", + "--warmup_steps=0", + "--val_check_interval=1.0", + "--tokenizer_name=facebook/mbart-large-en-ro", + ] + ) + with patch.object(sys, "argv", testargs): + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() + args.do_predict = False + # assert args.gpus == gpus THIS BREAKS for multigpu + model = main(args) - bash_script = bash_script.replace("--fp16 ", "") - testargs = ( - ["finetune.py"] - + bash_script.split() - + [ - f"--output_dir={output_dir}", - "--gpus=1", - "--learning_rate=3e-1", - "--warmup_steps=0", - "--val_check_interval=1.0", - "--tokenizer_name=facebook/mbart-large-en-ro", - ] - ) - with patch.object(sys, "argv", testargs): - parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) - args = parser.parse_args() - args.do_predict = False - # assert args.gpus == gpus THIS BREAKS for multigpu - model = main(args) + # Check metrics + metrics = load_json(model.metrics_save_path) + first_step_stats = metrics["val"][0] + last_step_stats = metrics["val"][-1] + assert ( + len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 + ) # +1 accounts for val_sanity_check - # Check metrics - metrics = load_json(model.metrics_save_path) - first_step_stats = metrics["val"][0] - last_step_stats = metrics["val"][-1] - assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check + assert last_step_stats["val_avg_gen_time"] >= 0.01 - assert last_step_stats["val_avg_gen_time"] >= 0.01 + assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing + assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. + assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) - assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing - assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. - assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) + # check lightning ckpt can be loaded and has a reasonable statedict + contents = os.listdir(output_dir) + ckpt_path = [x for x in contents if x.endswith(".ckpt")][0] + full_path = os.path.join(args.output_dir, ckpt_path) + ckpt = torch.load(full_path, map_location="cpu") + expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight" + assert expected_key in ckpt["state_dict"] + assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32 - # check lightning ckpt can be loaded and has a reasonable statedict - contents = os.listdir(output_dir) - ckpt_path = [x for x in contents if x.endswith(".ckpt")][0] - full_path = os.path.join(args.output_dir, ckpt_path) - ckpt = torch.load(full_path, map_location="cpu") - expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight" - assert expected_key in ckpt["state_dict"] - assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32 + # TODO: turn on args.do_predict when PL bug fixed. + if args.do_predict: + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.txt" in contents + # assert len(metrics["val"]) == desired_n_evals + assert len(metrics["test"]) == 1 - # TODO: turn on args.do_predict when PL bug fixed. - if args.do_predict: - contents = {os.path.basename(p) for p in contents} - assert "test_generations.txt" in contents - assert "test_results.txt" in contents - # assert len(metrics["val"]) == desired_n_evals - assert len(metrics["test"]) == 1 + @timeout_decorator.timeout(600) + @slow + @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + def test_opus_mt_distill_script(self): + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + env_vars_to_replace = { + "--fp16_opt_level=O1": "", + "$MAX_LEN": 128, + "$BS": 16, + "$GAS": 1, + "$ENRO_DIR": data_dir, + "$m": "sshleifer/student_marian_en_ro_6_1", + "val_check_interval=0.25": "val_check_interval=1.0", + } + # Clean up bash script + bash_script = ( + Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip() + ) + bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") + bash_script = bash_script.replace("--fp16 ", " ") -@timeout_decorator.timeout(600) -@slow -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") -def test_opus_mt_distill_script(): - data_dir = "examples/seq2seq/test_data/wmt_en_ro" - env_vars_to_replace = { - "--fp16_opt_level=O1": "", - "$MAX_LEN": 128, - "$BS": 16, - "$GAS": 1, - "$ENRO_DIR": data_dir, - "$m": "sshleifer/student_marian_en_ro_6_1", - "val_check_interval=0.25": "val_check_interval=1.0", - } + for k, v in env_vars_to_replace.items(): + bash_script = bash_script.replace(k, str(v)) + output_dir = self.get_auto_remove_tmp_dir() + bash_script = bash_script.replace("--fp16", "") + epochs = 6 + testargs = ( + ["distillation.py"] + + bash_script.split() + + [ + f"--output_dir={output_dir}", + "--gpus=1", + "--learning_rate=1e-3", + f"--num_train_epochs={epochs}", + "--warmup_steps=10", + "--val_check_interval=1.0", + ] + ) + with patch.object(sys, "argv", testargs): + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() + args.do_predict = False + # assert args.gpus == gpus THIS BREAKS for multigpu - # Clean up bash script - bash_script = ( - Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip() - ) - bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") - bash_script = bash_script.replace("--fp16 ", " ") + model = distill_main(args) - for k, v in env_vars_to_replace.items(): - bash_script = bash_script.replace(k, str(v)) - output_dir = tempfile.mkdtemp(prefix="marian_output") - bash_script = bash_script.replace("--fp16", "") - epochs = 6 - testargs = ( - ["distillation.py"] - + bash_script.split() - + [ - f"--output_dir={output_dir}", - "--gpus=1", - "--learning_rate=1e-3", - f"--num_train_epochs={epochs}", - "--warmup_steps=10", - "--val_check_interval=1.0", - ] - ) - with patch.object(sys, "argv", testargs): - parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) - parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) - args = parser.parse_args() - args.do_predict = False - # assert args.gpus == gpus THIS BREAKS for multigpu + # Check metrics + metrics = load_json(model.metrics_save_path) + first_step_stats = metrics["val"][0] + last_step_stats = metrics["val"][-1] + assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check - model = distill_main(args) + assert last_step_stats["val_avg_gen_time"] >= 0.01 - # Check metrics - metrics = load_json(model.metrics_save_path) - first_step_stats = metrics["val"][0] - last_step_stats = metrics["val"][-1] - assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check + assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing + assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. + assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) - assert last_step_stats["val_avg_gen_time"] >= 0.01 + # check lightning ckpt can be loaded and has a reasonable statedict + contents = os.listdir(output_dir) + ckpt_path = [x for x in contents if x.endswith(".ckpt")][0] + full_path = os.path.join(args.output_dir, ckpt_path) + ckpt = torch.load(full_path, map_location="cpu") + expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight" + assert expected_key in ckpt["state_dict"] + assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32 - assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing - assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. - assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) - - # check lightning ckpt can be loaded and has a reasonable statedict - contents = os.listdir(output_dir) - ckpt_path = [x for x in contents if x.endswith(".ckpt")][0] - full_path = os.path.join(args.output_dir, ckpt_path) - ckpt = torch.load(full_path, map_location="cpu") - expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight" - assert expected_key in ckpt["state_dict"] - assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32 - - # TODO: turn on args.do_predict when PL bug fixed. - if args.do_predict: - contents = {os.path.basename(p) for p in contents} - assert "test_generations.txt" in contents - assert "test_results.txt" in contents - # assert len(metrics["val"]) == desired_n_evals - assert len(metrics["test"]) == 1 + # TODO: turn on args.do_predict when PL bug fixed. + if args.do_predict: + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.txt" in contents + # assert len(metrics["val"]) == desired_n_evals + assert len(metrics["test"]) == 1 diff --git a/examples/seq2seq/test_data/wmt_en_ro/train.len b/examples/seq2seq/test_data/wmt_en_ro/train.len index 33ce003c8ae3139914a389a714812a2ab13aece4..2632a33e8b8a3a601bdf93f3b8e6783a0c7bba60 100644 GIT binary patch literal 34 lcmZo*nJUfz0kKmwye+)ry(fEDdzX5%dDnR7dRM3F0RV-J2?PKD literal 26 hcmZo*jxA)+@V4-d_nz!s?Op24=3V2R>s_6y2LNfV2s8iy diff --git a/examples/seq2seq/test_data/wmt_en_ro/val.len b/examples/seq2seq/test_data/wmt_en_ro/val.len index 897314a960b28d927b597805693e63f9de71d903..fdf8fa353eb8d41a6d92ccae9df4b2f8aaf970b5 100644 GIT binary patch delta 19 WcmdNe;B8=;s>%QXu~R0pDgyu?i~~*p delta 11 ScmXreU~OQIEo7L;s|)}RKmu?8 diff --git a/examples/seq2seq/test_datasets.py b/examples/seq2seq/test_datasets.py index dbd6769ad2..4b5c95ed4e 100644 --- a/examples/seq2seq/test_datasets.py +++ b/examples/seq2seq/test_datasets.py @@ -1,5 +1,4 @@ import os -import tempfile from pathlib import Path import numpy as np @@ -7,11 +6,12 @@ import pytest from torch.utils.data import DataLoader from pack_dataset import pack_data_dir +from parameterized import parameterized from save_len_file import save_len_file from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir from transformers import AutoTokenizer from transformers.modeling_bart import shift_tokens_right -from transformers.testing_utils import slow +from transformers.testing_utils import TestCasePlus, slow from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset @@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased" PEGASUS_XSUM = "google/pegasus-xsum" -@slow -@pytest.mark.parametrize( - "tok_name", - [ - MBART_TINY, - MARIAN_TINY, - T5_TINY, - BART_TINY, - PEGASUS_XSUM, - ], -) -def test_seq2seq_dataset_truncation(tok_name): - tokenizer = AutoTokenizer.from_pretrained(tok_name) - tmp_dir = make_test_data_dir() - max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) - max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) - max_src_len = 4 - max_tgt_len = 8 - assert max_len_target > max_src_len # Will be truncated - assert max_len_source > max_src_len # Will be truncated - src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error. - train_dataset = Seq2SeqDataset( - tokenizer, - data_dir=tmp_dir, - type_path="train", - max_source_length=max_src_len, - max_target_length=max_tgt_len, # ignored - src_lang=src_lang, - tgt_lang=tgt_lang, +class TestAll(TestCasePlus): + @parameterized.expand( + [ + MBART_TINY, + MARIAN_TINY, + T5_TINY, + BART_TINY, + PEGASUS_XSUM, + ], ) - dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) - for batch in dataloader: - assert isinstance(batch, dict) - assert batch["attention_mask"].shape == batch["input_ids"].shape - # show that articles were trimmed. - assert batch["input_ids"].shape[1] == max_src_len - # show that targets are the same len - assert batch["labels"].shape[1] == max_tgt_len - if tok_name != MBART_TINY: - continue - # check language codes in correct place - batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id) - assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang] - assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id - assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id - assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang] - - break # No need to test every batch - - -@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED]) -def test_legacy_dataset_truncation(tok): - tokenizer = AutoTokenizer.from_pretrained(tok) - tmp_dir = make_test_data_dir() - max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) - max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) - trunc_target = 4 - train_dataset = LegacySeq2SeqDataset( - tokenizer, - data_dir=tmp_dir, - type_path="train", - max_source_length=20, - max_target_length=trunc_target, - ) - dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) - for batch in dataloader: - assert batch["attention_mask"].shape == batch["input_ids"].shape - # show that articles were trimmed. - assert batch["input_ids"].shape[1] == max_len_source - assert 20 >= batch["input_ids"].shape[1] # trimmed significantly - # show that targets were truncated - assert batch["labels"].shape[1] == trunc_target # Truncated - assert max_len_target > trunc_target # Truncated - break # No need to test every batch - - -def test_pack_dataset(): - tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - tmp_dir = Path(make_test_data_dir()) - orig_examples = tmp_dir.joinpath("train.source").open().readlines() - save_dir = Path(tempfile.mkdtemp(prefix="packed_")) - pack_data_dir(tokenizer, tmp_dir, 128, save_dir) - orig_paths = {x.name for x in tmp_dir.iterdir()} - new_paths = {x.name for x in save_dir.iterdir()} - packed_examples = save_dir.joinpath("train.source").open().readlines() - # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.'] - # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.'] - assert len(packed_examples) < len(orig_examples) - assert len(packed_examples) == 1 - assert len(packed_examples[0]) == sum(len(x) for x in orig_examples) - assert orig_paths == new_paths - - -@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq") -def test_dynamic_batch_size(): - if not FAIRSEQ_AVAILABLE: - return - ds, max_tokens, tokenizer = _get_dataset(max_len=64) - required_batch_size_multiple = 64 - batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple) - batch_sizes = [len(x) for x in batch_sampler] - assert len(set(batch_sizes)) > 1 # it's not dynamic batch size if every batch is the same length - assert sum(batch_sizes) == len(ds) # no dropped or added examples - data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2) - failures = [] - num_src_per_batch = [] - for batch in data_loader: - src_shape = batch["input_ids"].shape - bs = src_shape[0] - assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple - num_src_tokens = np.product(batch["input_ids"].shape) - num_src_per_batch.append(num_src_tokens) - if num_src_tokens > (max_tokens * 1.1): - failures.append(num_src_tokens) - assert num_src_per_batch[0] == max(num_src_per_batch) - if failures: - raise AssertionError(f"too many tokens in {len(failures)} batches") - - -def test_sortish_sampler_reduces_padding(): - ds, _, tokenizer = _get_dataset(max_len=512) - bs = 2 - sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False) - - naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2) - sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler) - - pad = tokenizer.pad_token_id - - def count_pad_tokens(data_loader, k="input_ids"): - return [batch[k].eq(pad).sum().item() for batch in data_loader] - - assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels")) - assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl)) - assert len(sortish_dl) == len(naive_dl) - - -def _get_dataset(n_obs=1000, max_len=128): - if os.getenv("USE_REAL_DATA", False): - data_dir = "examples/seq2seq/wmt_en_ro" - max_tokens = max_len * 2 * 64 - if not Path(data_dir).joinpath("train.len").exists(): - save_len_file(MARIAN_TINY, data_dir) - else: - data_dir = "examples/seq2seq/test_data/wmt_en_ro" - max_tokens = max_len * 4 - save_len_file(MARIAN_TINY, data_dir) - - tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY) - ds = Seq2SeqDataset( - tokenizer, - data_dir=data_dir, - type_path="train", - max_source_length=max_len, - max_target_length=max_len, - n_obs=n_obs, - ) - return ds, max_tokens, tokenizer - - -def test_distributed_sortish_sampler_splits_indices_between_procs(): - ds, max_tokens, tokenizer = _get_dataset() - ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False)) - ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False)) - assert ids1.intersection(ids2) == set() - - -@pytest.mark.parametrize( - "tok_name", - [ - MBART_TINY, - MARIAN_TINY, - T5_TINY, - BART_TINY, - PEGASUS_XSUM, - ], -) -def test_dataset_kwargs(tok_name): - tokenizer = AutoTokenizer.from_pretrained(tok_name) - if tok_name == MBART_TINY: + @slow + def test_seq2seq_dataset_truncation(self, tok_name): + tokenizer = AutoTokenizer.from_pretrained(tok_name) + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) + max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) + max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) + max_src_len = 4 + max_tgt_len = 8 + assert max_len_target > max_src_len # Will be truncated + assert max_len_source > max_src_len # Will be truncated + src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error. train_dataset = Seq2SeqDataset( tokenizer, - data_dir=make_test_data_dir(), + data_dir=tmp_dir, type_path="train", - max_source_length=4, - max_target_length=8, - src_lang="EN", - tgt_lang="FR", + max_source_length=max_src_len, + max_target_length=max_tgt_len, # ignored + src_lang=src_lang, + tgt_lang=tgt_lang, ) - kwargs = train_dataset.dataset_kwargs - assert "src_lang" in kwargs and "tgt_lang" in kwargs - else: - train_dataset = Seq2SeqDataset( - tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8 + dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) + for batch in dataloader: + assert isinstance(batch, dict) + assert batch["attention_mask"].shape == batch["input_ids"].shape + # show that articles were trimmed. + assert batch["input_ids"].shape[1] == max_src_len + # show that targets are the same len + assert batch["labels"].shape[1] == max_tgt_len + if tok_name != MBART_TINY: + continue + # check language codes in correct place + batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id) + assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang] + assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id + assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id + assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang] + + break # No need to test every batch + + @parameterized.expand([BART_TINY, BERT_BASE_CASED]) + def test_legacy_dataset_truncation(self, tok): + tokenizer = AutoTokenizer.from_pretrained(tok) + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) + max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) + max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) + trunc_target = 4 + train_dataset = LegacySeq2SeqDataset( + tokenizer, + data_dir=tmp_dir, + type_path="train", + max_source_length=20, + max_target_length=trunc_target, ) - kwargs = train_dataset.dataset_kwargs - assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs - assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0 + dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) + for batch in dataloader: + assert batch["attention_mask"].shape == batch["input_ids"].shape + # show that articles were trimmed. + assert batch["input_ids"].shape[1] == max_len_source + assert 20 >= batch["input_ids"].shape[1] # trimmed significantly + # show that targets were truncated + assert batch["labels"].shape[1] == trunc_target # Truncated + assert max_len_target > trunc_target # Truncated + break # No need to test every batch + + def test_pack_dataset(self): + tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") + + tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())) + orig_examples = tmp_dir.joinpath("train.source").open().readlines() + save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())) + pack_data_dir(tokenizer, tmp_dir, 128, save_dir) + orig_paths = {x.name for x in tmp_dir.iterdir()} + new_paths = {x.name for x in save_dir.iterdir()} + packed_examples = save_dir.joinpath("train.source").open().readlines() + # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.'] + # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.'] + assert len(packed_examples) < len(orig_examples) + assert len(packed_examples) == 1 + assert len(packed_examples[0]) == sum(len(x) for x in orig_examples) + assert orig_paths == new_paths + + @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq") + def test_dynamic_batch_size(self): + if not FAIRSEQ_AVAILABLE: + return + ds, max_tokens, tokenizer = self._get_dataset(max_len=64) + required_batch_size_multiple = 64 + batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple) + batch_sizes = [len(x) for x in batch_sampler] + assert len(set(batch_sizes)) > 1 # it's not dynamic batch size if every batch is the same length + assert sum(batch_sizes) == len(ds) # no dropped or added examples + data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2) + failures = [] + num_src_per_batch = [] + for batch in data_loader: + src_shape = batch["input_ids"].shape + bs = src_shape[0] + assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple + num_src_tokens = np.product(batch["input_ids"].shape) + num_src_per_batch.append(num_src_tokens) + if num_src_tokens > (max_tokens * 1.1): + failures.append(num_src_tokens) + assert num_src_per_batch[0] == max(num_src_per_batch) + if failures: + raise AssertionError(f"too many tokens in {len(failures)} batches") + + def test_sortish_sampler_reduces_padding(self): + ds, _, tokenizer = self._get_dataset(max_len=512) + bs = 2 + sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False) + + naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2) + sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler) + + pad = tokenizer.pad_token_id + + def count_pad_tokens(data_loader, k="input_ids"): + return [batch[k].eq(pad).sum().item() for batch in data_loader] + + assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels")) + assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl)) + assert len(sortish_dl) == len(naive_dl) + + def _get_dataset(self, n_obs=1000, max_len=128): + if os.getenv("USE_REAL_DATA", False): + data_dir = "examples/seq2seq/wmt_en_ro" + max_tokens = max_len * 2 * 64 + if not Path(data_dir).joinpath("train.len").exists(): + save_len_file(MARIAN_TINY, data_dir) + else: + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + max_tokens = max_len * 4 + save_len_file(MARIAN_TINY, data_dir) + + tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY) + ds = Seq2SeqDataset( + tokenizer, + data_dir=data_dir, + type_path="train", + max_source_length=max_len, + max_target_length=max_len, + n_obs=n_obs, + ) + return ds, max_tokens, tokenizer + + def test_distributed_sortish_sampler_splits_indices_between_procs(self): + ds, max_tokens, tokenizer = self._get_dataset() + ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False)) + ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False)) + assert ids1.intersection(ids2) == set() + + @parameterized.expand( + [ + MBART_TINY, + MARIAN_TINY, + T5_TINY, + BART_TINY, + PEGASUS_XSUM, + ], + ) + def test_dataset_kwargs(self, tok_name): + tokenizer = AutoTokenizer.from_pretrained(tok_name) + if tok_name == MBART_TINY: + train_dataset = Seq2SeqDataset( + tokenizer, + data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()), + type_path="train", + max_source_length=4, + max_target_length=8, + src_lang="EN", + tgt_lang="FR", + ) + kwargs = train_dataset.dataset_kwargs + assert "src_lang" in kwargs and "tgt_lang" in kwargs + else: + train_dataset = Seq2SeqDataset( + tokenizer, + data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()), + type_path="train", + max_source_length=4, + max_target_length=8, + ) + kwargs = train_dataset.dataset_kwargs + assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs + assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0 diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 18ebd1e695..f335c1525e 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -1,9 +1,8 @@ import os import sys -import tempfile from unittest.mock import patch -from transformers.testing_utils import slow +from transformers.testing_utils import TestCasePlus, slow from transformers.trainer_callback import TrainerState from transformers.trainer_utils import set_seed @@ -15,72 +14,71 @@ set_seed(42) MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" -def test_finetune_trainer(): - output_dir = run_trainer(1, "12", MBART_TINY, 1) - logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history - eval_metrics = [log for log in logs if "eval_loss" in log.keys()] - first_step_stats = eval_metrics[0] - assert "eval_bleu" in first_step_stats +class TestFinetuneTrainer(TestCasePlus): + def test_finetune_trainer(self): + output_dir = self.run_trainer(1, "12", MBART_TINY, 1) + logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + assert "eval_bleu" in first_step_stats + @slow + def test_finetune_trainer_slow(self): + # There is a missing call to __init__process_group somewhere + output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3) -@slow -def test_finetune_trainer_slow(): - # There is a missing call to __init__process_group somewhere - output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3) + # Check metrics + logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + last_step_stats = eval_metrics[-1] - # Check metrics - logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history - eval_metrics = [log for log in logs if "eval_loss" in log.keys()] - first_step_stats = eval_metrics[0] - last_step_stats = eval_metrics[-1] + assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing + assert isinstance(last_step_stats["eval_bleu"], float) - assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing - assert isinstance(last_step_stats["eval_bleu"], float) + # test if do_predict saves generations and metrics + contents = os.listdir(output_dir) + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.json" in contents - # test if do_predict saves generations and metrics - contents = os.listdir(output_dir) - contents = {os.path.basename(p) for p in contents} - assert "test_generations.txt" in contents - assert "test_results.json" in contents + def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + output_dir = self.get_auto_remove_tmp_dir() + argv = f""" + --model_name_or_path {model_name} + --data_dir {data_dir} + --output_dir {output_dir} + --overwrite_output_dir + --n_train 8 + --n_val 8 + --max_source_length {max_len} + --max_target_length {max_len} + --val_max_target_length {max_len} + --do_train + --do_eval + --do_predict + --num_train_epochs {str(num_train_epochs)} + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --learning_rate 3e-4 + --warmup_steps 8 + --evaluate_during_training + --predict_with_generate + --logging_steps 0 + --save_steps {str(eval_steps)} + --eval_steps {str(eval_steps)} + --sortish_sampler + --label_smoothing 0.1 + --adafactor + --task translation + --tgt_lang ro_RO + --src_lang en_XX + """.split() + # --eval_beams 2 + testargs = ["finetune_trainer.py"] + argv + with patch.object(sys, "argv", testargs): + main() -def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): - data_dir = "examples/seq2seq/test_data/wmt_en_ro" - output_dir = tempfile.mkdtemp(prefix="test_output") - argv = f""" - --model_name_or_path {model_name} - --data_dir {data_dir} - --output_dir {output_dir} - --overwrite_output_dir - --n_train 8 - --n_val 8 - --max_source_length {max_len} - --max_target_length {max_len} - --val_max_target_length {max_len} - --do_train - --do_eval - --do_predict - --num_train_epochs {str(num_train_epochs)} - --per_device_train_batch_size 4 - --per_device_eval_batch_size 4 - --learning_rate 3e-4 - --warmup_steps 8 - --evaluate_during_training - --predict_with_generate - --logging_steps 0 - --save_steps {str(eval_steps)} - --eval_steps {str(eval_steps)} - --sortish_sampler - --label_smoothing 0.1 - --adafactor - --task translation - --tgt_lang ro_RO - --src_lang en_XX - """.split() - # --eval_beams 2 - - testargs = ["finetune_trainer.py"] + argv - with patch.object(sys, "argv", testargs): - main() - - return output_dir + return output_dir diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index b1c82c1e5d..51a6e6633c 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -3,7 +3,6 @@ import logging import os import sys import tempfile -import unittest from pathlib import Path from unittest.mock import patch @@ -15,11 +14,12 @@ import lightning_base from convert_pl_checkpoint_to_hf import convert_pl_to_hf from distillation import distill_main from finetune import SummarizationModule, main +from parameterized import parameterized from run_eval import generate_summaries_or_translations, run_generate from run_eval_search import run_search from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers.hf_api import HfApi -from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow +from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json @@ -52,7 +52,7 @@ CHEAP_ARGS = { "student_decoder_layers": 1, "val_check_interval": 1.0, "output_dir": "", - "fp16": False, # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp + "fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp "no_teacher": False, "fp16_opt_level": "O1", "gpus": 1 if CUDA_AVAILABLE else 0, @@ -88,6 +88,7 @@ CHEAP_ARGS = { "student_encoder_layers": 1, "freeze_encoder": False, "auto_scale_batch_size": False, + "overwrite_output_dir": False, } @@ -110,15 +111,14 @@ logger.addHandler(stream_handler) logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks -def make_test_data_dir(**kwargs): - tmp_dir = Path(tempfile.mkdtemp(**kwargs)) +def make_test_data_dir(tmp_dir): for split in ["train", "val", "test"]: - _dump_articles((tmp_dir / f"{split}.source"), ARTICLES) - _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES) + _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES) + _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES) return tmp_dir -class TestSummarizationDistiller(unittest.TestCase): +class TestSummarizationDistiller(TestCasePlus): @classmethod def setUpClass(cls): logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks @@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase): failures.append(m) assert not failures, f"The following models could not be loaded through AutoConfig: {failures}" - @require_multigpu - @unittest.skip("Broken at the moment") - def test_multigpu(self): - updates = dict( - no_teacher=True, - freeze_encoder=True, - gpus=2, - sortish_sampler=True, - ) - self._test_distiller_cli(updates, check_contents=False) - def test_distill_no_teacher(self): updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True) self._test_distiller_cli(updates) @@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase): self.assertEqual(1, len(ckpts)) transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) self.assertEqual(len(transformer_ckpts), 2) - examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines()) - out_path = tempfile.mktemp() + examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines()) + out_path = tempfile.mktemp() # XXX: not being cleaned up generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr")) self.assertTrue(Path(out_path).exists()) - out_path_new = tempfile.mkdtemp() + out_path_new = self.get_auto_remove_tmp_dir() convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new) assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin")) @@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase): ) default_updates.update(updates) args_d: dict = CHEAP_ARGS.copy() - tmp_dir = make_test_data_dir() - output_dir = tempfile.mkdtemp(prefix="output_") + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) + output_dir = self.get_auto_remove_tmp_dir() args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates) model = distill_main(argparse.Namespace(**args_d)) @@ -279,256 +268,253 @@ class TestSummarizationDistiller(unittest.TestCase): return model -def run_eval_tester(model): - input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" - output_file_name = input_file_name.parent / "utest_output.txt" - assert not output_file_name.exists() - articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] - _dump_articles(input_file_name, articles) - score_path = str(Path(tempfile.mkdtemp()) / "scores.json") - task = "translation_en_to_de" if model == T5_TINY else "summarization" - testargs = f""" - run_eval_search.py - {model} - {input_file_name} - {output_file_name} - --score_path {score_path} - --task {task} - --num_beams 2 - --length_penalty 2.0 - """.split() +class TestTheRest(TestCasePlus): + def run_eval_tester(self, model): + input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source" + output_file_name = input_file_name.parent / "utest_output.txt" + assert not output_file_name.exists() + articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] + _dump_articles(input_file_name, articles) - with patch.object(sys, "argv", testargs): - run_generate() - assert Path(output_file_name).exists() - os.remove(Path(output_file_name)) + score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json") + task = "translation_en_to_de" if model == T5_TINY else "summarization" + testargs = f""" + run_eval_search.py + {model} + {input_file_name} + {output_file_name} + --score_path {score_path} + --task {task} + --num_beams 2 + --length_penalty 2.0 + """.split() + with patch.object(sys, "argv", testargs): + run_generate() + assert Path(output_file_name).exists() + # os.remove(Path(output_file_name)) -# test one model to quickly (no-@slow) catch simple problems and do an -# extensive testing of functionality with multiple models as @slow separately -def test_run_eval(): - run_eval_tester(T5_TINY) + # test one model to quickly (no-@slow) catch simple problems and do an + # extensive testing of functionality with multiple models as @slow separately + def test_run_eval(self): + self.run_eval_tester(T5_TINY) + # any extra models should go into the list here - can be slow + @parameterized.expand([BART_TINY, MBART_TINY]) + @slow + def test_run_eval_slow(self, model): + self.run_eval_tester(model) -# any extra models should go into the list here - can be slow -@slow -@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY]) -def test_run_eval_slow(model): - run_eval_tester(model) + # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart) + @parameterized.expand([T5_TINY, MBART_TINY]) + @slow + def test_run_eval_search(self, model): + input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source" + output_file_name = input_file_name.parent / "utest_output.txt" + assert not output_file_name.exists() + text = { + "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"], + "de": [ + "Maschinelles Lernen ist großartig, oder?", + "Ich esse gerne Bananen", + "Morgen ist wieder ein toller Tag!", + ], + } -# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart) -@slow -@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY]) -def test_run_eval_search(model): - input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" - output_file_name = input_file_name.parent / "utest_output.txt" - assert not output_file_name.exists() + tmp_dir = Path(self.get_auto_remove_tmp_dir()) + score_path = str(tmp_dir / "scores.json") + reference_path = str(tmp_dir / "val.target") + _dump_articles(input_file_name, text["en"]) + _dump_articles(reference_path, text["de"]) + task = "translation_en_to_de" if model == T5_TINY else "summarization" + testargs = f""" + run_eval_search.py + {model} + {str(input_file_name)} + {str(output_file_name)} + --score_path {score_path} + --reference_path {reference_path} + --task {task} + """.split() + testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"]) - text = { - "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"], - "de": [ - "Maschinelles Lernen ist großartig, oder?", - "Ich esse gerne Bananen", - "Morgen ist wieder ein toller Tag!", - ], - } + with patch.object(sys, "argv", testargs): + with CaptureStdout() as cs: + run_search() + expected_strings = [" num_beams | length_penalty", model, "Best score args"] + un_expected_strings = ["Info"] + if "translation" in task: + expected_strings.append("bleu") + else: + expected_strings.extend(ROUGE_KEYS) + for w in expected_strings: + assert w in cs.out + for w in un_expected_strings: + assert w not in cs.out + assert Path(output_file_name).exists() + os.remove(Path(output_file_name)) - tmp_dir = Path(tempfile.mkdtemp()) - score_path = str(tmp_dir / "scores.json") - reference_path = str(tmp_dir / "val.target") - _dump_articles(input_file_name, text["en"]) - _dump_articles(reference_path, text["de"]) - task = "translation_en_to_de" if model == T5_TINY else "summarization" - testargs = f""" - run_eval_search.py - {model} - {str(input_file_name)} - {str(output_file_name)} - --score_path {score_path} - --reference_path {reference_path} - --task {task} - """.split() - testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"]) + @parameterized.expand( + [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY], + ) + def test_finetune(self, model): + args_d: dict = CHEAP_ARGS.copy() + task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization" + args_d["label_smoothing"] = 0.1 if task == "translation" else 0 - with patch.object(sys, "argv", testargs): - with CaptureStdout() as cs: - run_search() - expected_strings = [" num_beams | length_penalty", model, "Best score args"] - un_expected_strings = ["Info"] - if "translation" in task: - expected_strings.append("bleu") + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) + output_dir = self.get_auto_remove_tmp_dir() + args_d.update( + data_dir=tmp_dir, + model_name_or_path=model, + tokenizer_name=None, + train_batch_size=2, + eval_batch_size=2, + output_dir=output_dir, + do_predict=True, + task=task, + src_lang="en_XX", + tgt_lang="ro_RO", + freeze_encoder=True, + freeze_embeds=True, + ) + assert "n_train" in args_d + args = argparse.Namespace(**args_d) + module = main(args) + + input_embeds = module.model.get_input_embeddings() + assert not input_embeds.weight.requires_grad + if model == T5_TINY: + lm_head = module.model.lm_head + assert not lm_head.weight.requires_grad + assert (lm_head.weight == input_embeds.weight).all().item() + elif model == FSMT_TINY: + fsmt = module.model.model + embed_pos = fsmt.decoder.embed_positions + assert not embed_pos.weight.requires_grad + assert not fsmt.decoder.embed_tokens.weight.requires_grad + # check that embeds are not the same + assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens else: - expected_strings.extend(ROUGE_KEYS) - for w in expected_strings: - assert w in cs.out - for w in un_expected_strings: - assert w not in cs.out - assert Path(output_file_name).exists() - os.remove(Path(output_file_name)) + bart = module.model.model + embed_pos = bart.decoder.embed_positions + assert not embed_pos.weight.requires_grad + assert not bart.shared.weight.requires_grad + # check that embeds are the same + assert bart.decoder.embed_tokens == bart.encoder.embed_tokens + assert bart.decoder.embed_tokens == bart.shared + example_batch = load_json(module.output_dir / "text_batch.json") + assert isinstance(example_batch, dict) + assert len(example_batch) >= 4 -@pytest.mark.parametrize( - "model", - [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY], -) -def test_finetune(model): - args_d: dict = CHEAP_ARGS.copy() - task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization" - args_d["label_smoothing"] = 0.1 if task == "translation" else 0 + def test_finetune_extra_model_args(self): + args_d: dict = CHEAP_ARGS.copy() - tmp_dir = make_test_data_dir() - output_dir = tempfile.mkdtemp(prefix="output_") - args_d.update( - data_dir=tmp_dir, - model_name_or_path=model, - tokenizer_name=None, - train_batch_size=2, - eval_batch_size=2, - output_dir=output_dir, - do_predict=True, - task=task, - src_lang="en_XX", - tgt_lang="ro_RO", - freeze_encoder=True, - freeze_embeds=True, - ) - assert "n_train" in args_d - args = argparse.Namespace(**args_d) - module = main(args) + task = "summarization" + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) - input_embeds = module.model.get_input_embeddings() - assert not input_embeds.weight.requires_grad - if model == T5_TINY: - lm_head = module.model.lm_head - assert not lm_head.weight.requires_grad - assert (lm_head.weight == input_embeds.weight).all().item() - elif model == FSMT_TINY: - fsmt = module.model.model - embed_pos = fsmt.decoder.embed_positions - assert not embed_pos.weight.requires_grad - assert not fsmt.decoder.embed_tokens.weight.requires_grad - # check that embeds are not the same - assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens - else: - bart = module.model.model - embed_pos = bart.decoder.embed_positions - assert not embed_pos.weight.requires_grad - assert not bart.shared.weight.requires_grad - # check that embeds are the same - assert bart.decoder.embed_tokens == bart.encoder.embed_tokens - assert bart.decoder.embed_tokens == bart.shared + args_d.update( + data_dir=tmp_dir, + tokenizer_name=None, + train_batch_size=2, + eval_batch_size=2, + do_predict=False, + task=task, + src_lang="en_XX", + tgt_lang="ro_RO", + freeze_encoder=True, + freeze_embeds=True, + ) - example_batch = load_json(module.output_dir / "text_batch.json") - assert isinstance(example_batch, dict) - assert len(example_batch) >= 4 - - -def test_finetune_extra_model_args(): - args_d: dict = CHEAP_ARGS.copy() - - task = "summarization" - tmp_dir = make_test_data_dir() - - args_d.update( - data_dir=tmp_dir, - tokenizer_name=None, - train_batch_size=2, - eval_batch_size=2, - do_predict=False, - task=task, - src_lang="en_XX", - tgt_lang="ro_RO", - freeze_encoder=True, - freeze_embeds=True, - ) - - # test models whose config includes the extra_model_args - model = BART_TINY - output_dir = tempfile.mkdtemp(prefix="output_1_") - args_d1 = args_d.copy() - args_d1.update( - model_name_or_path=model, - output_dir=output_dir, - ) - extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") - for p in extra_model_params: - args_d1[p] = 0.5 - args = argparse.Namespace(**args_d1) - model = main(args) - for p in extra_model_params: - assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}" - - # test models whose config doesn't include the extra_model_args - model = T5_TINY - output_dir = tempfile.mkdtemp(prefix="output_2_") - args_d2 = args_d.copy() - args_d2.update( - model_name_or_path=model, - output_dir=output_dir, - ) - unsupported_param = "encoder_layerdrop" - args_d2[unsupported_param] = 0.5 - args = argparse.Namespace(**args_d2) - with pytest.raises(Exception) as excinfo: + # test models whose config includes the extra_model_args + model = BART_TINY + output_dir = self.get_auto_remove_tmp_dir() + args_d1 = args_d.copy() + args_d1.update( + model_name_or_path=model, + output_dir=output_dir, + ) + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") + for p in extra_model_params: + args_d1[p] = 0.5 + args = argparse.Namespace(**args_d1) model = main(args) - assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute" + for p in extra_model_params: + assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}" + # test models whose config doesn't include the extra_model_args + model = T5_TINY + output_dir = self.get_auto_remove_tmp_dir() + args_d2 = args_d.copy() + args_d2.update( + model_name_or_path=model, + output_dir=output_dir, + ) + unsupported_param = "encoder_layerdrop" + args_d2[unsupported_param] = 0.5 + args = argparse.Namespace(**args_d2) + with pytest.raises(Exception) as excinfo: + model = main(args) + assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute" -def test_finetune_lr_schedulers(): - args_d: dict = CHEAP_ARGS.copy() + def test_finetune_lr_schedulers(self): + args_d: dict = CHEAP_ARGS.copy() - task = "summarization" - tmp_dir = make_test_data_dir() + task = "summarization" + tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()) - model = BART_TINY - output_dir = tempfile.mkdtemp(prefix="output_1_") + model = BART_TINY + output_dir = self.get_auto_remove_tmp_dir() - args_d.update( - data_dir=tmp_dir, - model_name_or_path=model, - output_dir=output_dir, - tokenizer_name=None, - train_batch_size=2, - eval_batch_size=2, - do_predict=False, - task=task, - src_lang="en_XX", - tgt_lang="ro_RO", - freeze_encoder=True, - freeze_embeds=True, - ) + args_d.update( + data_dir=tmp_dir, + model_name_or_path=model, + output_dir=output_dir, + tokenizer_name=None, + train_batch_size=2, + eval_batch_size=2, + do_predict=False, + task=task, + src_lang="en_XX", + tgt_lang="ro_RO", + freeze_encoder=True, + freeze_embeds=True, + ) - # emulate finetune.py - parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) - args = {"--help": True} + # emulate finetune.py + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) + args = {"--help": True} - # --help test - with pytest.raises(SystemExit) as excinfo: - with CaptureStdout() as cs: - args = parser.parse_args(args) - assert False, "--help is expected to sys.exit" - assert excinfo.type == SystemExit - expected = lightning_base.arg_to_scheduler_metavar - assert expected in cs.out, "--help is expected to list the supported schedulers" + # --help test + with pytest.raises(SystemExit) as excinfo: + with CaptureStdout() as cs: + args = parser.parse_args(args) + assert False, "--help is expected to sys.exit" + assert excinfo.type == SystemExit + expected = lightning_base.arg_to_scheduler_metavar + assert expected in cs.out, "--help is expected to list the supported schedulers" - # --lr_scheduler=non_existing_scheduler test - unsupported_param = "non_existing_scheduler" - args = {f"--lr_scheduler={unsupported_param}"} - with pytest.raises(SystemExit) as excinfo: - with CaptureStderr() as cs: - args = parser.parse_args(args) - assert False, "invalid argument is expected to sys.exit" - assert excinfo.type == SystemExit - expected = f"invalid choice: '{unsupported_param}'" - assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}" + # --lr_scheduler=non_existing_scheduler test + unsupported_param = "non_existing_scheduler" + args = {f"--lr_scheduler={unsupported_param}"} + with pytest.raises(SystemExit) as excinfo: + with CaptureStderr() as cs: + args = parser.parse_args(args) + assert False, "invalid argument is expected to sys.exit" + assert excinfo.type == SystemExit + expected = f"invalid choice: '{unsupported_param}'" + assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}" - # --lr_scheduler=existing_scheduler test - supported_param = "cosine" - args_d1 = args_d.copy() - args_d1["lr_scheduler"] = supported_param - args = argparse.Namespace(**args_d1) - model = main(args) - assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail" + # --lr_scheduler=existing_scheduler test + supported_param = "cosine" + args_d1 = args_d.copy() + args_d1["lr_scheduler"] = supported_param + args = argparse.Namespace(**args_d1) + model = main(args) + assert ( + getattr(model.hparams, "lr_scheduler") == supported_param + ), f"lr_scheduler={supported_param} shouldn't fail"