[s2sTrainer] test + code cleanup (#7467)

This commit is contained in:
Sam Shleifer
2020-10-01 00:33:01 -04:00
committed by GitHub
parent 097049b81b
commit 48f23f92a8
5 changed files with 102 additions and 116 deletions

View File

@@ -26,6 +26,7 @@ from utils import (
calculate_bleu,
calculate_rouge,
flatten_list,
freeze_embeds,
freeze_params,
get_git_info,
label_smoothed_nll_loss,
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
if self.hparams.freeze_embeds:
self.freeze_embeds()
freeze_embeds(self.model)
if self.hparams.freeze_encoder:
freeze_params(self.model.get_encoder())
assert_all_frozen(self.model.get_encoder())
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
if self.hparams.eval_max_gen_length is not None:
self.eval_max_length = self.hparams.eval_max_gen_length
else:
self.eval_max_length = self.model.config.max_length
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
if self.model_type == "t5":
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
elif self.model_type == "fsmt":
for d in [self.model.model.encoder, self.model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
freeze_params(self.model.model.shared)
for d in [self.model.model.encoder, self.model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)