[s2s] adjust finetune + test to work with fsmt (#7263)
This commit is contained in:
@@ -61,6 +61,8 @@ class SummarizationModule(BaseTransformer):
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
self.model_type = self.config.model_type
|
||||
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
@@ -106,14 +108,18 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
try:
|
||||
freeze_params(self.model.model.shared)
|
||||
if self.model_type == "t5":
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
elif self.model_type == "fsmt":
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
except AttributeError:
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
else:
|
||||
freeze_params(self.model.model.shared)
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
@@ -140,7 +146,7 @@ class SummarizationModule(BaseTransformer):
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
assert lm_logits.shape[-1] == self.vocab_size
|
||||
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user