[s2s] distributed eval in one command (#7124)
This commit is contained in:
@@ -18,6 +18,7 @@ from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from transformers import BartTokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def make_sortish_sampler(self, batch_size, distributed=False):
|
||||
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
|
||||
if distributed:
|
||||
return DistributedSortishSampler(self, batch_size)
|
||||
return DistributedSortishSampler(self, batch_size, **kwargs)
|
||||
else:
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
return {
|
||||
"tgt_texts": tgt_line,
|
||||
"src_texts": source_line,
|
||||
}
|
||||
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
"""Call prepare_seq2seq_batch."""
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
max_target_length=self.max_target_length,
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
).data
|
||||
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
|
||||
return batch_encoding
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
|
||||
class DistributedSortishSampler(Sampler):
|
||||
"""Copied from torch DistributedSampler"""
|
||||
|
||||
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
|
||||
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
if add_extra_examples:
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
else:
|
||||
self.total_size = len(dataset)
|
||||
self.num_samples = len(self.available_indices)
|
||||
self.batch_size = batch_size
|
||||
self.add_extra_examples = add_extra_examples
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
available_indices = self.get_indices_for_rank() # indices[self.rank: self.total_size: self.num_replicas]
|
||||
|
||||
sortish_data = [self.dataset.src_lens[i] for i in available_indices]
|
||||
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
||||
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
|
||||
indices = [available_indices[i] for i in sortish_indices]
|
||||
indices = [self.available_indices[i] for i in sortish_indices]
|
||||
assert len(indices) == self.num_samples
|
||||
return iter(indices)
|
||||
|
||||
def get_indices_for_rank(self) -> np.array:
|
||||
@cached_property
|
||||
def available_indices(self) -> np.array:
|
||||
indices = list(range(len(self.dataset)))
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
|
||||
Reference in New Issue
Block a user