Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -117,7 +117,12 @@ class TestSummarizationDistiller(unittest.TestCase):
@require_multigpu
def test_multigpu(self):
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
updates = dict(
no_teacher=True,
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
)
self._test_distiller_cli(updates)
def test_distill_no_teacher(self):
@@ -261,7 +266,8 @@ def test_run_eval_bart(model):
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
["model"],
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
@@ -329,7 +335,8 @@ def test_finetune_extra_model_args():
output_dir = tempfile.mkdtemp(prefix="output_1_")
args_d1 = args_d.copy()
args_d1.update(
model_name_or_path=model, output_dir=output_dir,
model_name_or_path=model,
output_dir=output_dir,
)
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
@@ -344,7 +351,8 @@ def test_finetune_extra_model_args():
output_dir = tempfile.mkdtemp(prefix="output_2_")
args_d2 = args_d.copy()
args_d2.update(
model_name_or_path=model, output_dir=output_dir,
model_name_or_path=model,
output_dir=output_dir,
)
unsupported_param = "encoder_layerdrop"
args_d2[unsupported_param] = 0.5
@@ -478,7 +486,11 @@ def test_summarization_dataset_truncation(tok):
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = Seq2SeqDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader: