[s2s test] cleanup (#8131)
This commit is contained in:
@@ -1,117 +1,17 @@
|
|||||||
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
|
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from transformers import is_torch_available
|
|
||||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
||||||
|
|
||||||
|
from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir
|
||||||
from .utils import load_json
|
from .utils import load_json
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
|
||||||
CHEAP_ARGS = {
|
|
||||||
"max_tokens_per_batch": None,
|
|
||||||
"supervise_forward": True,
|
|
||||||
"normalize_hidden": True,
|
|
||||||
"label_smoothing": 0.2,
|
|
||||||
"eval_max_gen_length": None,
|
|
||||||
"eval_beams": 1,
|
|
||||||
"val_metric": "loss",
|
|
||||||
"save_top_k": 1,
|
|
||||||
"adafactor": True,
|
|
||||||
"early_stopping_patience": 2,
|
|
||||||
"logger_name": "default",
|
|
||||||
"length_penalty": 0.5,
|
|
||||||
"cache_dir": "",
|
|
||||||
"task": "summarization",
|
|
||||||
"num_workers": 2,
|
|
||||||
"alpha_hid": 0,
|
|
||||||
"freeze_embeds": True,
|
|
||||||
"enc_only": False,
|
|
||||||
"tgt_suffix": "",
|
|
||||||
"resume_from_checkpoint": None,
|
|
||||||
"sortish_sampler": True,
|
|
||||||
"student_decoder_layers": 1,
|
|
||||||
"val_check_interval": 1.0,
|
|
||||||
"output_dir": "",
|
|
||||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
|
||||||
"no_teacher": False,
|
|
||||||
"fp16_opt_level": "O1",
|
|
||||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
|
||||||
"n_tpu_cores": 0,
|
|
||||||
"max_grad_norm": 1.0,
|
|
||||||
"do_train": True,
|
|
||||||
"do_predict": True,
|
|
||||||
"accumulate_grad_batches": 1,
|
|
||||||
"server_ip": "",
|
|
||||||
"server_port": "",
|
|
||||||
"seed": 42,
|
|
||||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
|
||||||
"config_name": "",
|
|
||||||
"tokenizer_name": "facebook/bart-large",
|
|
||||||
"do_lower_case": False,
|
|
||||||
"learning_rate": 0.3,
|
|
||||||
"lr_scheduler": "linear",
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"adam_epsilon": 1e-08,
|
|
||||||
"warmup_steps": 0,
|
|
||||||
"max_epochs": 1,
|
|
||||||
"train_batch_size": 2,
|
|
||||||
"eval_batch_size": 2,
|
|
||||||
"max_source_length": 12,
|
|
||||||
"max_target_length": 12,
|
|
||||||
"val_max_target_length": 12,
|
|
||||||
"test_max_target_length": 12,
|
|
||||||
"fast_dev_run": False,
|
|
||||||
"no_cache": False,
|
|
||||||
"n_train": -1,
|
|
||||||
"n_val": -1,
|
|
||||||
"n_test": -1,
|
|
||||||
"student_encoder_layers": 1,
|
|
||||||
"freeze_encoder": False,
|
|
||||||
"auto_scale_batch_size": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _dump_articles(path: Path, articles: list):
|
|
||||||
content = "\n".join(articles)
|
|
||||||
Path(path).open("w").writelines(content)
|
|
||||||
|
|
||||||
|
|
||||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
|
||||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
|
||||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
|
||||||
BART_TINY = "sshleifer/bart-tiny-random"
|
|
||||||
MBART_TINY = "sshleifer/tiny-mbart"
|
|
||||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
|
||||||
|
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
||||||
|
|
||||||
|
|
||||||
def make_test_data_dir(tmp_dir):
|
|
||||||
for split in ["train", "val", "test"]:
|
|
||||||
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
|
||||||
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
|
||||||
return tmp_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@require_torch_multigpu
|
@require_torch_multigpu
|
||||||
|
|||||||
Reference in New Issue
Block a user