[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
|
|||||||
from finetune import SummarizationModule, main
|
from finetune import SummarizationModule, main
|
||||||
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
||||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import TestCasePlus, slow
|
||||||
from utils import load_json
|
from utils import load_json
|
||||||
|
|
||||||
|
|
||||||
@@ -24,18 +23,18 @@ MODEL_NAME = MBART_TINY
|
|||||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
@slow
|
class TestAll(TestCasePlus):
|
||||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
@slow
|
||||||
def test_model_download():
|
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||||
|
def test_model_download(self):
|
||||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||||
|
|
||||||
|
@timeout_decorator.timeout(120)
|
||||||
@timeout_decorator.timeout(120)
|
@slow
|
||||||
@slow
|
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
def test_train_mbart_cc25_enro_script(self):
|
||||||
def test_train_mbart_cc25_enro_script():
|
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
env_vars_to_replace = {
|
env_vars_to_replace = {
|
||||||
"--fp16_opt_level=O1": "",
|
"--fp16_opt_level=O1": "",
|
||||||
@@ -53,7 +52,7 @@ def test_train_mbart_cc25_enro_script():
|
|||||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||||
for k, v in env_vars_to_replace.items():
|
for k, v in env_vars_to_replace.items():
|
||||||
bash_script = bash_script.replace(k, str(v))
|
bash_script = bash_script.replace(k, str(v))
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_mbart")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
bash_script = bash_script.replace("--fp16 ", "")
|
bash_script = bash_script.replace("--fp16 ", "")
|
||||||
testargs = (
|
testargs = (
|
||||||
@@ -81,7 +80,9 @@ def test_train_mbart_cc25_enro_script():
|
|||||||
metrics = load_json(model.metrics_save_path)
|
metrics = load_json(model.metrics_save_path)
|
||||||
first_step_stats = metrics["val"][0]
|
first_step_stats = metrics["val"][0]
|
||||||
last_step_stats = metrics["val"][-1]
|
last_step_stats = metrics["val"][-1]
|
||||||
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
|
assert (
|
||||||
|
len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
|
||||||
|
) # +1 accounts for val_sanity_check
|
||||||
|
|
||||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||||
|
|
||||||
@@ -106,11 +107,10 @@ def test_train_mbart_cc25_enro_script():
|
|||||||
# assert len(metrics["val"]) == desired_n_evals
|
# assert len(metrics["val"]) == desired_n_evals
|
||||||
assert len(metrics["test"]) == 1
|
assert len(metrics["test"]) == 1
|
||||||
|
|
||||||
|
@timeout_decorator.timeout(600)
|
||||||
@timeout_decorator.timeout(600)
|
@slow
|
||||||
@slow
|
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
def test_opus_mt_distill_script(self):
|
||||||
def test_opus_mt_distill_script():
|
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
env_vars_to_replace = {
|
env_vars_to_replace = {
|
||||||
"--fp16_opt_level=O1": "",
|
"--fp16_opt_level=O1": "",
|
||||||
@@ -131,7 +131,7 @@ def test_opus_mt_distill_script():
|
|||||||
|
|
||||||
for k, v in env_vars_to_replace.items():
|
for k, v in env_vars_to_replace.items():
|
||||||
bash_script = bash_script.replace(k, str(v))
|
bash_script = bash_script.replace(k, str(v))
|
||||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
bash_script = bash_script.replace("--fp16", "")
|
bash_script = bash_script.replace("--fp16", "")
|
||||||
epochs = 6
|
epochs = 6
|
||||||
testargs = (
|
testargs = (
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -7,11 +6,12 @@ import pytest
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from pack_dataset import pack_data_dir
|
from pack_dataset import pack_data_dir
|
||||||
|
from parameterized import parameterized
|
||||||
from save_len_file import save_len_file
|
from save_len_file import save_len_file
|
||||||
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
|
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import TestCasePlus, slow
|
||||||
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
|
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
|
||||||
|
|
||||||
|
|
||||||
@@ -19,9 +19,8 @@ BERT_BASE_CASED = "bert-base-cased"
|
|||||||
PEGASUS_XSUM = "google/pegasus-xsum"
|
PEGASUS_XSUM = "google/pegasus-xsum"
|
||||||
|
|
||||||
|
|
||||||
@slow
|
class TestAll(TestCasePlus):
|
||||||
@pytest.mark.parametrize(
|
@parameterized.expand(
|
||||||
"tok_name",
|
|
||||||
[
|
[
|
||||||
MBART_TINY,
|
MBART_TINY,
|
||||||
MARIAN_TINY,
|
MARIAN_TINY,
|
||||||
@@ -29,10 +28,11 @@ PEGASUS_XSUM = "google/pegasus-xsum"
|
|||||||
BART_TINY,
|
BART_TINY,
|
||||||
PEGASUS_XSUM,
|
PEGASUS_XSUM,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_seq2seq_dataset_truncation(tok_name):
|
@slow
|
||||||
|
def test_seq2seq_dataset_truncation(self, tok_name):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
max_src_len = 4
|
max_src_len = 4
|
||||||
@@ -68,11 +68,10 @@ def test_seq2seq_dataset_truncation(tok_name):
|
|||||||
|
|
||||||
break # No need to test every batch
|
break # No need to test every batch
|
||||||
|
|
||||||
|
@parameterized.expand([BART_TINY, BERT_BASE_CASED])
|
||||||
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
|
def test_legacy_dataset_truncation(self, tok):
|
||||||
def test_legacy_dataset_truncation(tok):
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
trunc_target = 4
|
trunc_target = 4
|
||||||
@@ -94,13 +93,12 @@ def test_legacy_dataset_truncation(tok):
|
|||||||
assert max_len_target > trunc_target # Truncated
|
assert max_len_target > trunc_target # Truncated
|
||||||
break # No need to test every batch
|
break # No need to test every batch
|
||||||
|
|
||||||
|
def test_pack_dataset(self):
|
||||||
def test_pack_dataset():
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||||
|
|
||||||
tmp_dir = Path(make_test_data_dir())
|
tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
|
||||||
orig_examples = tmp_dir.joinpath("train.source").open().readlines()
|
orig_examples = tmp_dir.joinpath("train.source").open().readlines()
|
||||||
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
|
||||||
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
||||||
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
||||||
new_paths = {x.name for x in save_dir.iterdir()}
|
new_paths = {x.name for x in save_dir.iterdir()}
|
||||||
@@ -112,12 +110,11 @@ def test_pack_dataset():
|
|||||||
assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
|
assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
|
||||||
assert orig_paths == new_paths
|
assert orig_paths == new_paths
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
|
||||||
@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
|
def test_dynamic_batch_size(self):
|
||||||
def test_dynamic_batch_size():
|
|
||||||
if not FAIRSEQ_AVAILABLE:
|
if not FAIRSEQ_AVAILABLE:
|
||||||
return
|
return
|
||||||
ds, max_tokens, tokenizer = _get_dataset(max_len=64)
|
ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
|
||||||
required_batch_size_multiple = 64
|
required_batch_size_multiple = 64
|
||||||
batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
|
batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
|
||||||
batch_sizes = [len(x) for x in batch_sampler]
|
batch_sizes = [len(x) for x in batch_sampler]
|
||||||
@@ -138,9 +135,8 @@ def test_dynamic_batch_size():
|
|||||||
if failures:
|
if failures:
|
||||||
raise AssertionError(f"too many tokens in {len(failures)} batches")
|
raise AssertionError(f"too many tokens in {len(failures)} batches")
|
||||||
|
|
||||||
|
def test_sortish_sampler_reduces_padding(self):
|
||||||
def test_sortish_sampler_reduces_padding():
|
ds, _, tokenizer = self._get_dataset(max_len=512)
|
||||||
ds, _, tokenizer = _get_dataset(max_len=512)
|
|
||||||
bs = 2
|
bs = 2
|
||||||
sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
|
sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
|
||||||
|
|
||||||
@@ -156,8 +152,7 @@ def test_sortish_sampler_reduces_padding():
|
|||||||
assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
|
assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
|
||||||
assert len(sortish_dl) == len(naive_dl)
|
assert len(sortish_dl) == len(naive_dl)
|
||||||
|
|
||||||
|
def _get_dataset(self, n_obs=1000, max_len=128):
|
||||||
def _get_dataset(n_obs=1000, max_len=128):
|
|
||||||
if os.getenv("USE_REAL_DATA", False):
|
if os.getenv("USE_REAL_DATA", False):
|
||||||
data_dir = "examples/seq2seq/wmt_en_ro"
|
data_dir = "examples/seq2seq/wmt_en_ro"
|
||||||
max_tokens = max_len * 2 * 64
|
max_tokens = max_len * 2 * 64
|
||||||
@@ -179,16 +174,13 @@ def _get_dataset(n_obs=1000, max_len=128):
|
|||||||
)
|
)
|
||||||
return ds, max_tokens, tokenizer
|
return ds, max_tokens, tokenizer
|
||||||
|
|
||||||
|
def test_distributed_sortish_sampler_splits_indices_between_procs(self):
|
||||||
def test_distributed_sortish_sampler_splits_indices_between_procs():
|
ds, max_tokens, tokenizer = self._get_dataset()
|
||||||
ds, max_tokens, tokenizer = _get_dataset()
|
|
||||||
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
|
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
|
||||||
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
|
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
|
||||||
assert ids1.intersection(ids2) == set()
|
assert ids1.intersection(ids2) == set()
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"tok_name",
|
|
||||||
[
|
[
|
||||||
MBART_TINY,
|
MBART_TINY,
|
||||||
MARIAN_TINY,
|
MARIAN_TINY,
|
||||||
@@ -196,13 +188,13 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
|
|||||||
BART_TINY,
|
BART_TINY,
|
||||||
PEGASUS_XSUM,
|
PEGASUS_XSUM,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_dataset_kwargs(tok_name):
|
def test_dataset_kwargs(self, tok_name):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||||
if tok_name == MBART_TINY:
|
if tok_name == MBART_TINY:
|
||||||
train_dataset = Seq2SeqDataset(
|
train_dataset = Seq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir=make_test_data_dir(),
|
data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
|
||||||
type_path="train",
|
type_path="train",
|
||||||
max_source_length=4,
|
max_source_length=4,
|
||||||
max_target_length=8,
|
max_target_length=8,
|
||||||
@@ -213,7 +205,11 @@ def test_dataset_kwargs(tok_name):
|
|||||||
assert "src_lang" in kwargs and "tgt_lang" in kwargs
|
assert "src_lang" in kwargs and "tgt_lang" in kwargs
|
||||||
else:
|
else:
|
||||||
train_dataset = Seq2SeqDataset(
|
train_dataset = Seq2SeqDataset(
|
||||||
tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
|
tokenizer,
|
||||||
|
data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
|
||||||
|
type_path="train",
|
||||||
|
max_source_length=4,
|
||||||
|
max_target_length=8,
|
||||||
)
|
)
|
||||||
kwargs = train_dataset.dataset_kwargs
|
kwargs = train_dataset.dataset_kwargs
|
||||||
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
|
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import TestCasePlus, slow
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
@@ -15,18 +14,18 @@ set_seed(42)
|
|||||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_trainer():
|
class TestFinetuneTrainer(TestCasePlus):
|
||||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
def test_finetune_trainer(self):
|
||||||
|
output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
assert "eval_bleu" in first_step_stats
|
assert "eval_bleu" in first_step_stats
|
||||||
|
|
||||||
|
@slow
|
||||||
@slow
|
def test_finetune_trainer_slow(self):
|
||||||
def test_finetune_trainer_slow():
|
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
@@ -43,10 +42,9 @@ def test_finetune_trainer_slow():
|
|||||||
assert "test_generations.txt" in contents
|
assert "test_generations.txt" in contents
|
||||||
assert "test_results.json" in contents
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
|
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||||
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = tempfile.mkdtemp(prefix="test_output")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
argv = f"""
|
argv = f"""
|
||||||
--model_name_or_path {model_name}
|
--model_name_or_path {model_name}
|
||||||
--data_dir {data_dir}
|
--data_dir {data_dir}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -15,11 +14,12 @@ import lightning_base
|
|||||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||||
from distillation import distill_main
|
from distillation import distill_main
|
||||||
from finetune import SummarizationModule, main
|
from finetune import SummarizationModule, main
|
||||||
|
from parameterized import parameterized
|
||||||
from run_eval import generate_summaries_or_translations, run_generate
|
from run_eval import generate_summaries_or_translations, run_generate
|
||||||
from run_eval_search import run_search
|
from run_eval_search import run_search
|
||||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||||
from transformers.hf_api import HfApi
|
from transformers.hf_api import HfApi
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
|
||||||
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
|
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
|
||||||
|
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
|
|||||||
"student_decoder_layers": 1,
|
"student_decoder_layers": 1,
|
||||||
"val_check_interval": 1.0,
|
"val_check_interval": 1.0,
|
||||||
"output_dir": "",
|
"output_dir": "",
|
||||||
"fp16": False, # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||||
"no_teacher": False,
|
"no_teacher": False,
|
||||||
"fp16_opt_level": "O1",
|
"fp16_opt_level": "O1",
|
||||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||||
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
|
|||||||
"student_encoder_layers": 1,
|
"student_encoder_layers": 1,
|
||||||
"freeze_encoder": False,
|
"freeze_encoder": False,
|
||||||
"auto_scale_batch_size": False,
|
"auto_scale_batch_size": False,
|
||||||
|
"overwrite_output_dir": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
|
|||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||||
|
|
||||||
|
|
||||||
def make_test_data_dir(**kwargs):
|
def make_test_data_dir(tmp_dir):
|
||||||
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
|
|
||||||
for split in ["train", "val", "test"]:
|
for split in ["train", "val", "test"]:
|
||||||
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
|
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
||||||
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
|
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
||||||
return tmp_dir
|
return tmp_dir
|
||||||
|
|
||||||
|
|
||||||
class TestSummarizationDistiller(unittest.TestCase):
|
class TestSummarizationDistiller(TestCasePlus):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||||
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
failures.append(m)
|
failures.append(m)
|
||||||
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
||||||
|
|
||||||
@require_multigpu
|
|
||||||
@unittest.skip("Broken at the moment")
|
|
||||||
def test_multigpu(self):
|
|
||||||
updates = dict(
|
|
||||||
no_teacher=True,
|
|
||||||
freeze_encoder=True,
|
|
||||||
gpus=2,
|
|
||||||
sortish_sampler=True,
|
|
||||||
)
|
|
||||||
self._test_distiller_cli(updates, check_contents=False)
|
|
||||||
|
|
||||||
def test_distill_no_teacher(self):
|
def test_distill_no_teacher(self):
|
||||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||||
self._test_distiller_cli(updates)
|
self._test_distiller_cli(updates)
|
||||||
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
self.assertEqual(1, len(ckpts))
|
self.assertEqual(1, len(ckpts))
|
||||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||||
self.assertEqual(len(transformer_ckpts), 2)
|
self.assertEqual(len(transformer_ckpts), 2)
|
||||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
|
||||||
out_path = tempfile.mktemp()
|
out_path = tempfile.mktemp() # XXX: not being cleaned up
|
||||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||||
self.assertTrue(Path(out_path).exists())
|
self.assertTrue(Path(out_path).exists())
|
||||||
|
|
||||||
out_path_new = tempfile.mkdtemp()
|
out_path_new = self.get_auto_remove_tmp_dir()
|
||||||
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||||
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||||
|
|
||||||
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
default_updates.update(updates)
|
default_updates.update(updates)
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||||
model = distill_main(argparse.Namespace(**args_d))
|
model = distill_main(argparse.Namespace(**args_d))
|
||||||
@@ -279,13 +268,15 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def run_eval_tester(model):
|
class TestTheRest(TestCasePlus):
|
||||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
def run_eval_tester(self, model):
|
||||||
|
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
assert not output_file_name.exists()
|
assert not output_file_name.exists()
|
||||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||||
_dump_articles(input_file_name, articles)
|
_dump_articles(input_file_name, articles)
|
||||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
|
||||||
|
score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
|
||||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_eval_search.py
|
run_eval_search.py
|
||||||
@@ -301,27 +292,24 @@ def run_eval_tester(model):
|
|||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_generate()
|
run_generate()
|
||||||
assert Path(output_file_name).exists()
|
assert Path(output_file_name).exists()
|
||||||
os.remove(Path(output_file_name))
|
# os.remove(Path(output_file_name))
|
||||||
|
|
||||||
|
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||||
|
# extensive testing of functionality with multiple models as @slow separately
|
||||||
|
def test_run_eval(self):
|
||||||
|
self.run_eval_tester(T5_TINY)
|
||||||
|
|
||||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
# any extra models should go into the list here - can be slow
|
||||||
# extensive testing of functionality with multiple models as @slow separately
|
@parameterized.expand([BART_TINY, MBART_TINY])
|
||||||
def test_run_eval():
|
@slow
|
||||||
run_eval_tester(T5_TINY)
|
def test_run_eval_slow(self, model):
|
||||||
|
self.run_eval_tester(model)
|
||||||
|
|
||||||
|
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||||
# any extra models should go into the list here - can be slow
|
@parameterized.expand([T5_TINY, MBART_TINY])
|
||||||
@slow
|
@slow
|
||||||
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
|
def test_run_eval_search(self, model):
|
||||||
def test_run_eval_slow(model):
|
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||||
run_eval_tester(model)
|
|
||||||
|
|
||||||
|
|
||||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
|
||||||
@slow
|
|
||||||
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
|
|
||||||
def test_run_eval_search(model):
|
|
||||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
|
||||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
assert not output_file_name.exists()
|
assert not output_file_name.exists()
|
||||||
|
|
||||||
@@ -334,7 +322,7 @@ def test_run_eval_search(model):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
tmp_dir = Path(tempfile.mkdtemp())
|
tmp_dir = Path(self.get_auto_remove_tmp_dir())
|
||||||
score_path = str(tmp_dir / "scores.json")
|
score_path = str(tmp_dir / "scores.json")
|
||||||
reference_path = str(tmp_dir / "val.target")
|
reference_path = str(tmp_dir / "val.target")
|
||||||
_dump_articles(input_file_name, text["en"])
|
_dump_articles(input_file_name, text["en"])
|
||||||
@@ -367,18 +355,16 @@ def test_run_eval_search(model):
|
|||||||
assert Path(output_file_name).exists()
|
assert Path(output_file_name).exists()
|
||||||
os.remove(Path(output_file_name))
|
os.remove(Path(output_file_name))
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model",
|
|
||||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||||
)
|
)
|
||||||
def test_finetune(model):
|
def test_finetune(self, model):
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||||
|
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
args_d.update(
|
args_d.update(
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
model_name_or_path=model,
|
model_name_or_path=model,
|
||||||
@@ -423,12 +409,11 @@ def test_finetune(model):
|
|||||||
assert isinstance(example_batch, dict)
|
assert isinstance(example_batch, dict)
|
||||||
assert len(example_batch) >= 4
|
assert len(example_batch) >= 4
|
||||||
|
|
||||||
|
def test_finetune_extra_model_args(self):
|
||||||
def test_finetune_extra_model_args():
|
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
|
||||||
task = "summarization"
|
task = "summarization"
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
|
|
||||||
args_d.update(
|
args_d.update(
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
@@ -445,7 +430,7 @@ def test_finetune_extra_model_args():
|
|||||||
|
|
||||||
# test models whose config includes the extra_model_args
|
# test models whose config includes the extra_model_args
|
||||||
model = BART_TINY
|
model = BART_TINY
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
args_d1 = args_d.copy()
|
args_d1 = args_d.copy()
|
||||||
args_d1.update(
|
args_d1.update(
|
||||||
model_name_or_path=model,
|
model_name_or_path=model,
|
||||||
@@ -461,7 +446,7 @@ def test_finetune_extra_model_args():
|
|||||||
|
|
||||||
# test models whose config doesn't include the extra_model_args
|
# test models whose config doesn't include the extra_model_args
|
||||||
model = T5_TINY
|
model = T5_TINY
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_2_")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
args_d2 = args_d.copy()
|
args_d2 = args_d.copy()
|
||||||
args_d2.update(
|
args_d2.update(
|
||||||
model_name_or_path=model,
|
model_name_or_path=model,
|
||||||
@@ -474,15 +459,14 @@ def test_finetune_extra_model_args():
|
|||||||
model = main(args)
|
model = main(args)
|
||||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||||
|
|
||||||
|
def test_finetune_lr_schedulers(self):
|
||||||
def test_finetune_lr_schedulers():
|
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
|
||||||
task = "summarization"
|
task = "summarization"
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||||
|
|
||||||
model = BART_TINY
|
model = BART_TINY
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
args_d.update(
|
args_d.update(
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
@@ -531,4 +515,6 @@ def test_finetune_lr_schedulers():
|
|||||||
args_d1["lr_scheduler"] = supported_param
|
args_d1["lr_scheduler"] = supported_param
|
||||||
args = argparse.Namespace(**args_d1)
|
args = argparse.Namespace(**args_d1)
|
||||||
model = main(args)
|
model = main(args)
|
||||||
assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
|
assert (
|
||||||
|
getattr(model.hparams, "lr_scheduler") == supported_param
|
||||||
|
), f"lr_scheduler={supported_param} shouldn't fail"
|
||||||
|
|||||||
Reference in New Issue
Block a user