prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -3,7 +3,6 @@ import json
|
||||
import linecache
|
||||
import os
|
||||
import pickle
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
|
||||
|
||||
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
"""Only used by LegacyDataset"""
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||
return tokenizer(
|
||||
[line],
|
||||
@@ -75,7 +75,7 @@ def trim_batch(
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
class AbstractSeq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset):
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
def __getitem__(self, item):
|
||||
raise NotImplementedError("You must implement this")
|
||||
|
||||
def collate_fn(self, batch):
|
||||
raise NotImplementedError("You must implement this")
|
||||
|
||||
|
||||
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
"""Call tokenizer on src and tgt_lines"""
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset):
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
"labels": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
target_ids = torch.stack([x["labels"] for x in batch])
|
||||
pad_token_id = self.pad_token_id
|
||||
y = trim_batch(target_ids, pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
"labels": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
|
||||
class TranslationDataset(Seq2SeqDataset):
|
||||
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
"""A dataset that calls prepare_seq2seq_batch."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.max_source_length != self.max_target_length:
|
||||
warnings.warn(
|
||||
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
|
||||
f"Imbalanced sequence lengths may be undesired for translation tasks"
|
||||
)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset):
|
||||
}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
"""Call prepare_seq2seq_batch."""
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.src_lang,
|
||||
@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset):
|
||||
tgt_lang=self.tgt_lang,
|
||||
max_length=self.max_source_length,
|
||||
max_target_length=self.max_target_length,
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
|
||||
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
||||
|
||||
|
||||
# Utilities for freezing parameters and checking whether they are frozen
|
||||
|
||||
|
||||
def freeze_params(model: nn.Module):
|
||||
"""Set requires_grad=False for each of model.parameters()"""
|
||||
for par in model.parameters():
|
||||
par.requires_grad = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user