Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -145,18 +145,18 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||
updates = {"student_encoder_layers": 2, "student_decoder_layers": 1, "no_teacher": True}
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_checkpointing_with_teacher(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
max_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
)
|
||||
updates = {
|
||||
"student_encoder_layers": 2,
|
||||
"student_decoder_layers": 1,
|
||||
"max_epochs": 4,
|
||||
"val_check_interval": 0.25,
|
||||
"alpha_hid": 2.0,
|
||||
"model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
}
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
@@ -193,19 +193,19 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
self.assertEqual(nll_loss, model_computed_loss)
|
||||
|
||||
def test_distill_mbart(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
task="translation",
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
tokenizer_name=MBART_TINY,
|
||||
teacher=MBART_TINY,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
)
|
||||
updates = {
|
||||
"student_encoder_layers": 2,
|
||||
"student_decoder_layers": 1,
|
||||
"num_train_epochs": 4,
|
||||
"val_check_interval": 0.25,
|
||||
"alpha_hid": 2.0,
|
||||
"task": "translation",
|
||||
"model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
"tokenizer_name": MBART_TINY,
|
||||
"teacher": MBART_TINY,
|
||||
"src_lang": "en_XX",
|
||||
"tgt_lang": "ro_RO",
|
||||
}
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
assert model.model.config.model_type == "mbart"
|
||||
|
||||
@@ -217,39 +217,39 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher=T5_TINY,
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=T5_TINY,
|
||||
)
|
||||
updates = {
|
||||
"student_encoder_layers": 1,
|
||||
"student_decoder_layers": 1,
|
||||
"alpha_hid": 2.0,
|
||||
"teacher": T5_TINY,
|
||||
"model_name_or_path": T5_TINY,
|
||||
"tokenizer_name": T5_TINY,
|
||||
}
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_different_base_models(self):
|
||||
updates = dict(
|
||||
teacher=T5_TINY,
|
||||
student=T5_TINIER,
|
||||
model_name_or_path=T5_TINIER,
|
||||
tokenizer_name=T5_TINIER,
|
||||
)
|
||||
updates = {
|
||||
"teacher": T5_TINY,
|
||||
"student": T5_TINIER,
|
||||
"model_name_or_path": T5_TINIER,
|
||||
"tokenizer_name": T5_TINIER,
|
||||
}
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def _test_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates = {
|
||||
"label_smoothing": 0.0,
|
||||
"early_stopping_patience": -1,
|
||||
"train_batch_size": 1,
|
||||
"eval_batch_size": 2,
|
||||
"max_epochs": 2,
|
||||
"alpha_mlm": 0.2,
|
||||
"alpha_ce": 0.8,
|
||||
"do_predict": True,
|
||||
"model_name_or_path": "sshleifer/tinier_bart",
|
||||
"teacher": CHEAP_ARGS["model_name_or_path"],
|
||||
"val_check_interval": 0.5,
|
||||
}
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
@@ -98,29 +98,29 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu(self):
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
overwrite_output_dir=True,
|
||||
sortish_sampler=True,
|
||||
)
|
||||
updates = {
|
||||
"no_teacher": True,
|
||||
"freeze_encoder": True,
|
||||
"gpus": 2,
|
||||
"overwrite_output_dir": True,
|
||||
"sortish_sampler": True,
|
||||
}
|
||||
self._test_distiller_cli_fork(updates, check_contents=False)
|
||||
|
||||
def _test_distiller_cli_fork(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates = {
|
||||
"label_smoothing": 0.0,
|
||||
"early_stopping_patience": -1,
|
||||
"train_batch_size": 1,
|
||||
"eval_batch_size": 2,
|
||||
"max_epochs": 2,
|
||||
"alpha_mlm": 0.2,
|
||||
"alpha_ce": 0.8,
|
||||
"do_predict": True,
|
||||
"model_name_or_path": "sshleifer/tinier_bart",
|
||||
"teacher": CHEAP_ARGS["model_name_or_path"],
|
||||
"val_check_interval": 0.5,
|
||||
}
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
@@ -74,11 +74,11 @@ class SummarizationModule(BaseTransformer):
|
||||
self.model_type = self.config.model_type
|
||||
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=self.model.config.prefix or "",
|
||||
)
|
||||
self.dataset_kwargs: dict = {
|
||||
"data_dir": self.hparams.data_dir,
|
||||
"max_source_length": self.hparams.max_source_length,
|
||||
"prefix": self.model.config.prefix or "",
|
||||
}
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
@@ -433,7 +433,7 @@ def main(args, model=None) -> SummarizationModule:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
|
||||
@@ -171,11 +171,11 @@ def create_student_by_copying_alternating_layers(
|
||||
logger.info(
|
||||
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
|
||||
)
|
||||
student.config.init_metadata = dict(
|
||||
teacher_type=teacher.config.model_type,
|
||||
copied_encoder_layers=e_layers_to_copy,
|
||||
copied_decoder_layers=d_layers_to_copy,
|
||||
)
|
||||
student.config.init_metadata = {
|
||||
"teacher_type": teacher.config.model_type,
|
||||
"copied_encoder_layers": e_layers_to_copy,
|
||||
"copied_decoder_layers": d_layers_to_copy,
|
||||
}
|
||||
student.save_pretrained(save_path)
|
||||
# Save information about copying for easier reproducibility
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def generate_summaries_or_translations(
|
||||
fout.close()
|
||||
runtime = int(time.time() - start_time) # seconds
|
||||
n_obs = len(examples)
|
||||
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
|
||||
return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}
|
||||
|
||||
|
||||
def datetime_now():
|
||||
|
||||
@@ -437,7 +437,7 @@ def pickle_save(obj, path):
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
return list(itertools.chain.from_iterable(summary_ids))
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user