prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)
* broken test
* batch parity
* tests pass
* boom boom
* boom boom
* split out bart tokenizer tests
* fix tests
* boom boom
* Fixed dataset bug
* Fix marian
* Undo extra
* Get marian working
* Fix t5 tok tests
* Test passing
* Cleanup
* better assert msg
* require torch
* Fix mbart tests
* undo extra decoder_attn_mask change
* Fix import
* pegasus tokenizer can ignore src_lang kwargs
* unused kwarg test cov
* boom boom
* add todo for pegasus issue
* cover one word translation edge case
* Cleanup
* doc
This commit is contained in:
@@ -13,15 +13,16 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from .utils import (
|
||||
ROUGE_KEYS,
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
@@ -39,8 +40,8 @@ except ImportError:
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from utils import (
|
||||
ROUGE_KEYS,
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
@@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None
|
||||
self.decoder_start_token_id = None # default to config
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||
if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
|
||||
self.dataset_class = TranslationDataset
|
||||
else:
|
||||
self.dataset_class = Seq2SeqDataset
|
||||
self.dataset_class = (
|
||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
)
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
@@ -134,27 +134,25 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
|
||||
tgt_ids = batch["labels"]
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
decoder_input_ids = self.model._shift_right(tgt_ids)
|
||||
else:
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
||||
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
||||
|
||||
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||
lm_logits = outputs[0]
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
lm_logits = outputs[0]
|
||||
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
|
||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
return (loss,)
|
||||
|
||||
@@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
|
||||
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
@@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer):
|
||||
)
|
||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
target: List[str] = self.ids_to_clean_text(batch["labels"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
Reference in New Issue
Block a user