diff --git a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py index 03ec39037c..463ad1e7d9 100644 --- a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py +++ b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py @@ -1,117 +1,17 @@ # as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module. -import logging import os import sys -from pathlib import Path -from transformers import is_torch_available from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu +from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir from .utils import load_json -if is_torch_available(): - import torch - - -logging.basicConfig(level=logging.DEBUG) - -logger = logging.getLogger() -CUDA_AVAILABLE = torch.cuda.is_available() -CHEAP_ARGS = { - "max_tokens_per_batch": None, - "supervise_forward": True, - "normalize_hidden": True, - "label_smoothing": 0.2, - "eval_max_gen_length": None, - "eval_beams": 1, - "val_metric": "loss", - "save_top_k": 1, - "adafactor": True, - "early_stopping_patience": 2, - "logger_name": "default", - "length_penalty": 0.5, - "cache_dir": "", - "task": "summarization", - "num_workers": 2, - "alpha_hid": 0, - "freeze_embeds": True, - "enc_only": False, - "tgt_suffix": "", - "resume_from_checkpoint": None, - "sortish_sampler": True, - "student_decoder_layers": 1, - "val_check_interval": 1.0, - "output_dir": "", - "fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp - "no_teacher": False, - "fp16_opt_level": "O1", - "gpus": 1 if CUDA_AVAILABLE else 0, - "n_tpu_cores": 0, - "max_grad_norm": 1.0, - "do_train": True, - "do_predict": True, - "accumulate_grad_batches": 1, - "server_ip": "", - "server_port": "", - "seed": 42, - "model_name_or_path": "sshleifer/bart-tiny-random", - "config_name": "", - "tokenizer_name": "facebook/bart-large", - "do_lower_case": False, - "learning_rate": 0.3, - "lr_scheduler": "linear", - "weight_decay": 0.0, - "adam_epsilon": 1e-08, - "warmup_steps": 0, - "max_epochs": 1, - "train_batch_size": 2, - "eval_batch_size": 2, - "max_source_length": 12, - "max_target_length": 12, - "val_max_target_length": 12, - "test_max_target_length": 12, - "fast_dev_run": False, - "no_cache": False, - "n_train": -1, - "n_val": -1, - "n_test": -1, - "student_encoder_layers": 1, - "freeze_encoder": False, - "auto_scale_batch_size": False, -} - - -def _dump_articles(path: Path, articles: list): - content = "\n".join(articles) - Path(path).open("w").writelines(content) - - -ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."] -SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] -T5_TINY = "patrickvonplaten/t5-tiny-random" -BART_TINY = "sshleifer/bart-tiny-random" -MBART_TINY = "sshleifer/tiny-mbart" -MARIAN_TINY = "sshleifer/tiny-marian-en-de" - - -stream_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(stream_handler) -logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks - - -def make_test_data_dir(tmp_dir): - for split in ["train", "val", "test"]: - _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES) - _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES) - return tmp_dir - - class TestSummarizationDistillerMultiGPU(TestCasePlus): @classmethod def setUpClass(cls): - logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks return cls @require_torch_multigpu