[s2s] dynamic batch size with --max_tokens_per_batch (#7030)
This commit is contained in:
@@ -21,6 +21,14 @@ from transformers import BartTokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
|
||||
|
||||
try:
|
||||
from fairseq.data.data_utils import batch_by_size
|
||||
|
||||
FAIRSEQ_AVAILABLE = True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
FAIRSEQ_AVAILABLE = False
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
"""From fairseq"""
|
||||
if target.dim() == lprobs.dim() - 1:
|
||||
@@ -94,7 +102,13 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.len_file = Path(data_dir).joinpath(type_path + ".len")
|
||||
if os.path.exists(self.len_file):
|
||||
self.src_lens = pickle_load(self.len_file)
|
||||
self.used_char_len = False
|
||||
else:
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.used_char_len = True
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
@@ -115,12 +129,42 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
@cached_property
|
||||
def tgt_lens(self):
|
||||
"""Length in characters of target documents"""
|
||||
return self.get_char_lens(self.tgt_file)
|
||||
|
||||
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
|
||||
if distributed:
|
||||
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
|
||||
else:
|
||||
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
|
||||
|
||||
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
|
||||
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
|
||||
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
|
||||
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
|
||||
|
||||
def num_tokens_in_example(i):
|
||||
return min(self.src_lens[i], self.max_target_length)
|
||||
|
||||
# call fairseq cython function
|
||||
batch_sampler: List[List[int]] = batch_by_size(
|
||||
sorted_indices,
|
||||
num_tokens_fn=num_tokens_in_example,
|
||||
max_tokens=max_tokens_per_batch,
|
||||
required_batch_size_multiple=64,
|
||||
)
|
||||
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
|
||||
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
|
||||
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
|
||||
largest_batch_idx = np.argmax(approximate_toks_per_batch)
|
||||
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
|
||||
shuffled_batches[largest_batch_idx],
|
||||
shuffled_batches[0],
|
||||
)
|
||||
return shuffled_batches
|
||||
|
||||
def __getitem__(self, item):
|
||||
raise NotImplementedError("You must implement this")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user