[s2s] dynamic batch size with --max_tokens_per_batch (#7030)

This commit is contained in:
Sam Shleifer
2020-09-17 15:19:34 -04:00
committed by GitHub
parent efeab6a3f1
commit a5638b2b3a
11 changed files with 385 additions and 116 deletions

View File

@@ -21,6 +21,14 @@ from transformers import BartTokenizer
from transformers.file_utils import cached_property
try:
from fairseq.data.data_utils import batch_by_size
FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
FAIRSEQ_AVAILABLE = False
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
@@ -94,7 +102,13 @@ class AbstractSeq2SeqDataset(Dataset):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.src_lens = self.get_char_lens(self.src_file)
self.len_file = Path(data_dir).joinpath(type_path + ".len")
if os.path.exists(self.len_file):
self.src_lens = pickle_load(self.len_file)
self.used_char_len = False
else:
self.src_lens = self.get_char_lens(self.src_file)
self.used_char_len = True
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
@@ -115,12 +129,42 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
@cached_property
def tgt_lens(self):
"""Length in characters of target documents"""
return self.get_char_lens(self.tgt_file)
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
def num_tokens_in_example(i):
return min(self.src_lens[i], self.max_target_length)
# call fairseq cython function
batch_sampler: List[List[int]] = batch_by_size(
sorted_indices,
num_tokens_fn=num_tokens_in_example,
max_tokens=max_tokens_per_batch,
required_batch_size_multiple=64,
)
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
largest_batch_idx = np.argmax(approximate_toks_per_batch)
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
shuffled_batches[largest_batch_idx],
shuffled_batches[0],
)
return shuffled_batches
def __getitem__(self, item):
raise NotImplementedError("You must implement this")