prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -3,7 +3,6 @@ import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
"""Only used by LegacyDataset"""
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer(
[line],
@@ -75,7 +75,7 @@ def trim_batch(
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class Seq2SeqDataset(Dataset):
class AbstractSeq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset):
self.pad_token_id = self.tokenizer.pad_token_id
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
def __len__(self):
return len(self.src_lens)
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
def collate_fn(self, batch):
raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset):
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
"labels": target_ids,
}
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
"labels": y,
}
return batch
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
class TranslationDataset(Seq2SeqDataset):
class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""A dataset that calls prepare_seq2seq_batch."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:
warnings.warn(
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
f"Imbalanced sequence lengths may be undesired for translation tasks"
)
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset):
}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.src_lang,
@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset):
tgt_lang=self.tgt_lang,
max_length=self.max_source_length,
max_target_length=self.max_target_length,
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
# Utilities for freezing parameters and checking whether they are frozen
def freeze_params(model: nn.Module):
"""Set requires_grad=False for each of model.parameters()"""
for par in model.parameters():
par.requires_grad = False