[s2s] distributed eval in one command (#7124)

This commit is contained in:
Sam Shleifer
2020-09-14 15:57:56 -04:00
committed by GitHub
parent 206b78d485
commit 33d479d2b2
4 changed files with 125 additions and 85 deletions

View File

@@ -18,6 +18,7 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer
from transformers.file_utils import cached_property
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size, distributed=False):
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size)
return DistributedSortishSampler(self, batch_size, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size)
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {
"tgt_texts": tgt_line,
"src_texts": source_line,
}
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
max_target_length=self.max_target_length,
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
).data
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
return batch_encoding
class SortishSampler(Sampler):
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
if add_extra_examples:
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
else:
self.total_size = len(dataset)
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
available_indices = self.get_indices_for_rank() # indices[self.rank: self.total_size: self.num_replicas]
sortish_data = [self.dataset.src_lens[i] for i in available_indices]
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
indices = [available_indices[i] for i in sortish_indices]
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
def get_indices_for_rank(self) -> np.array:
@cached_property
def available_indices(self) -> np.array:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]