[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -26,6 +26,7 @@ from utils import (
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
freeze_embeds(self.model)
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.get_encoder())
|
||||
assert_all_frozen(self.model.get_encoder())
|
||||
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
|
||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
)
|
||||
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
||||
if self.hparams.eval_max_gen_length is not None:
|
||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||
else:
|
||||
self.eval_max_length = self.model.config.max_length
|
||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||
|
||||
def freeze_embeds(self):
    """Hand the embedding modules of ``self.model`` to ``freeze_params``.

    Which modules are passed depends on ``self.model_type``:

    * ``"t5"``: the shared embedding plus ``embed_tokens`` of the encoder and
      decoder (token embeddings only).
    * ``"fsmt"``: ``embed_positions`` and ``embed_tokens`` of the encoder and
      decoder.
    * anything else (bart-style layout): the shared embedding plus
      ``embed_positions`` and ``embed_tokens`` of the encoder and decoder.
    """
    kind = self.model_type
    if kind == "t5":
        t5 = self.model
        freeze_params(t5.shared)
        for stack in (t5.encoder, t5.decoder):
            freeze_params(stack.embed_tokens)
        return

    # fsmt and the bart-style default keep encoder/decoder one level down,
    # under ``self.model.model``.
    inner = self.model.model
    if kind != "fsmt":
        # Only the non-fsmt layout has a shared embedding to pass along.
        freeze_params(inner.shared)
    for stack in (inner.encoder, inner.decoder):
        freeze_params(stack.embed_positions)
        freeze_params(stack.embed_tokens)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
    """Call ``self.model`` with ``input_ids`` and return whatever it returns.

    Any extra keyword arguments are forwarded to the model unchanged.
    """
    outputs = self.model(input_ids, **kwargs)
    return outputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user