[s2s] distributed eval cleanup (#7186)

This commit is contained in:
Sam Shleifer
2020-09-16 15:38:37 -04:00
committed by GitHub
parent 3babef815c
commit 0203ad43bc
4 changed files with 47 additions and 31 deletions

View File

@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size, **kwargs)
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size)
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size):
self.data, self.bs = data, batch_size
def __init__(self, data, batch_size, shuffle=True):
self.data, self.bs, self.shuffle = data, batch_size, shuffle
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
return iter(sortish_sampler_indices(self.data, self.bs))
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int) -> np.array:
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
def key_fn(i):
return data[i]
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
self.shuffle = shuffle
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path):
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=4)
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):