Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -145,18 +145,18 @@ class TestSummarizationDistiller(TestCasePlus):
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
updates = {"student_encoder_layers": 2, "student_decoder_layers": 1, "no_teacher": True}
self._test_distiller_cli(updates)
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
)
updates = {
"student_encoder_layers": 2,
"student_decoder_layers": 1,
"max_epochs": 4,
"val_check_interval": 0.25,
"alpha_hid": 2.0,
"model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
}
model = self._test_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
@@ -193,19 +193,19 @@ class TestSummarizationDistiller(TestCasePlus):
self.assertEqual(nll_loss, model_computed_loss)
def test_distill_mbart(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
task="translation",
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
tokenizer_name=MBART_TINY,
teacher=MBART_TINY,
src_lang="en_XX",
tgt_lang="ro_RO",
)
updates = {
"student_encoder_layers": 2,
"student_decoder_layers": 1,
"num_train_epochs": 4,
"val_check_interval": 0.25,
"alpha_hid": 2.0,
"task": "translation",
"model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
"tokenizer_name": MBART_TINY,
"teacher": MBART_TINY,
"src_lang": "en_XX",
"tgt_lang": "ro_RO",
}
model = self._test_distiller_cli(updates, check_contents=False)
assert model.model.config.model_type == "mbart"
@@ -217,39 +217,39 @@ class TestSummarizationDistiller(TestCasePlus):
self.assertEqual(len(transformer_ckpts), 2)
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher=T5_TINY,
model_name_or_path=T5_TINY,
tokenizer_name=T5_TINY,
)
updates = {
"student_encoder_layers": 1,
"student_decoder_layers": 1,
"alpha_hid": 2.0,
"teacher": T5_TINY,
"model_name_or_path": T5_TINY,
"tokenizer_name": T5_TINY,
}
self._test_distiller_cli(updates)
def test_distill_different_base_models(self):
updates = dict(
teacher=T5_TINY,
student=T5_TINIER,
model_name_or_path=T5_TINIER,
tokenizer_name=T5_TINIER,
)
updates = {
"teacher": T5_TINY,
"student": T5_TINIER,
"model_name_or_path": T5_TINIER,
"tokenizer_name": T5_TINIER,
}
self._test_distiller_cli(updates)
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates = {
"label_smoothing": 0.0,
"early_stopping_patience": -1,
"train_batch_size": 1,
"eval_batch_size": 2,
"max_epochs": 2,
"alpha_mlm": 0.2,
"alpha_ce": 0.8,
"do_predict": True,
"model_name_or_path": "sshleifer/tinier_bart",
"teacher": CHEAP_ARGS["model_name_or_path"],
"val_check_interval": 0.5,
}
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

View File

@@ -98,29 +98,29 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
@require_torch_multi_gpu
def test_multi_gpu(self):
updates = dict(
no_teacher=True,
freeze_encoder=True,
gpus=2,
overwrite_output_dir=True,
sortish_sampler=True,
)
updates = {
"no_teacher": True,
"freeze_encoder": True,
"gpus": 2,
"overwrite_output_dir": True,
"sortish_sampler": True,
}
self._test_distiller_cli_fork(updates, check_contents=False)
def _test_distiller_cli_fork(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates = {
"label_smoothing": 0.0,
"early_stopping_patience": -1,
"train_batch_size": 1,
"eval_batch_size": 2,
"max_epochs": 2,
"alpha_mlm": 0.2,
"alpha_ce": 0.8,
"do_predict": True,
"model_name_or_path": "sshleifer/tinier_bart",
"teacher": CHEAP_ARGS["model_name_or_path"],
"val_check_interval": 0.5,
}
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

View File

@@ -74,11 +74,11 @@ class SummarizationModule(BaseTransformer):
self.model_type = self.config.model_type
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=self.model.config.prefix or "",
)
self.dataset_kwargs: dict = {
"data_dir": self.hparams.data_dir,
"max_source_length": self.hparams.max_source_length,
"prefix": self.model.config.prefix or "",
}
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
@@ -433,7 +433,7 @@ def main(args, model=None) -> SummarizationModule:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]

View File

@@ -171,11 +171,11 @@ def create_student_by_copying_alternating_layers(
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.config.init_metadata = dict(
teacher_type=teacher.config.model_type,
copied_encoder_layers=e_layers_to_copy,
copied_decoder_layers=d_layers_to_copy,
)
student.config.init_metadata = {
"teacher_type": teacher.config.model_type,
"copied_encoder_layers": e_layers_to_copy,
"copied_decoder_layers": d_layers_to_copy,
}
student.save_pretrained(save_path)
# Save information about copying for easier reproducibility

View File

@@ -63,7 +63,7 @@ def generate_summaries_or_translations(
fout.close()
runtime = int(time.time() - start_time) # seconds
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}
def datetime_now():

View File

@@ -437,7 +437,7 @@ def pickle_save(obj, path):
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
return list(itertools.chain.from_iterable(summary_ids))
def save_git_info(folder_path: str) -> None: