[wip/s2s] DistributedSortishSampler (#7056)

This commit is contained in:
Sam Shleifer
2020-09-10 15:23:44 -04:00
committed by GitHub
parent 514486739c
commit 77950c485a
3 changed files with 79 additions and 24 deletions

View File

@@ -1,6 +1,7 @@
import itertools
import json
import linecache
import math
import os
import pickle
from logging import getLogger
@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
    """Return the character length of every line in ``data_file``.

    Lines are not stripped, so each length includes the trailing newline
    character when the line has one.

    Args:
        data_file: path (``str`` or ``Path``) of the text file to measure.

    Returns:
        A list with one integer per line, in file order.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with Path(data_file).open() as f:
        return [len(line) for line in f]
def make_sortish_sampler(self, batch_size, distributed=False):
    """Build a sampler that groups examples of similar source length.

    Args:
        batch_size: batch size forwarded to the sampler.
        distributed: when true, return a ``DistributedSortishSampler`` built
            from this dataset; otherwise return a ``SortishSampler`` built
            from ``self.src_lens``. Defaults to ``False``, so callers that
            pass only ``batch_size`` keep getting a ``SortishSampler``.

    Returns:
        The constructed sampler instance.
    """
    if distributed:
        return DistributedSortishSampler(self, batch_size)
    return SortishSampler(self.src_lens, batch_size)
def __getitem__(self, item):
    """Placeholder item accessor; always raises ``NotImplementedError``.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "You must implement this"
    raise NotImplementedError(message)
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
def __init__(self, data, batch_size):
    """Remember the per-example keys (``data``) and the batch size (``bs``)."""
    self.data = data
    self.bs = batch_size
def key(self, i):
    """Return the entry of ``self.data`` stored at position ``i``."""
    value = self.data[i]
    return value
def __len__(self) -> int:
    """Return the number of entries held in ``self.data``."""
    n_items = len(self.data)
    return n_items
def __iter__(self):
    """Iterate over dataset indices in sortish (length-grouped, shuffled) order.

    The ordering is produced by the module-level ``sortish_sampler_indices``
    helper from ``self.data`` and ``self.bs``, so this sampler shares one
    implementation with the other callers of that helper instead of carrying
    an inline copy of the same algorithm.

    Returns:
        An iterator over the index array returned by the helper.
    """
    return iter(sortish_sampler_indices(self.data, self.bs))
def sortish_sampler_indices(data: List, bs: int) -> np.ndarray:
    """Go through the text data by order of src length with a bit of randomness.

    From fastai repo.

    Args:
        data: sequence where ``data[i]`` is the sort key (e.g. source length)
            of example ``i``.
        bs: batch size; indices are grouped into consecutive chunks of this
            size.

    Returns:
        A 1-D integer array containing every index in ``range(len(data))``
        exactly once. The chunk whose first element has the largest key is
        placed first; the remaining chunks follow in random order. Empty
        ``data`` yields an empty array.
    """
    # ``len`` rather than truthiness: ``data`` may be an array-like object.
    if len(data) == 0:
        return np.array([], dtype=int)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    # Sort (descending by key) inside windows of 50 batches, so neighbouring
    # examples have similar keys while the global order stays random.
    window = bs * 50
    windows = [idxs[i : i + window] for i in range(0, len(idxs), window)]
    sort_idx = np.concatenate([sorted(w, key=key_fn, reverse=True) for w in windows])
    batches = [sort_idx[i : i + bs] for i in range(0, len(sort_idx), bs)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in batches])  # find the chunk with the largest key,
    batches[0], batches[max_ck] = batches[max_ck], batches[0]  # then make sure it goes first.
    first, rest = batches[0], batches[1:]
    # Shuffle chunk *positions* instead of the chunk list itself: the last
    # chunk may be shorter than the others, and recent NumPy refuses to build
    # an array out of such a ragged list. This also avoids the removed
    # ``np.int`` alias previously needed for the single-chunk case.
    order = np.random.permutation(len(rest))
    return np.concatenate([first] + [rest[i] for i in order])
class DistributedSortishSampler(Sampler):
    """Sortish sampler that gives each process its own shard of the dataset.

    Copied from torch DistributedSampler: the index list is padded so it
    divides evenly across ``num_replicas`` and each rank takes every
    ``num_replicas``-th index starting at ``rank``. The shard is then ordered
    with ``sortish_sampler_indices`` using ``dataset.src_lens`` as keys.
    """

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
        """Set up the shard bookkeeping for one process.

        Args:
            dataset: object supporting ``len()``; ``__iter__`` additionally
                reads its ``src_lens`` attribute.
            batch_size: chunk size passed to ``sortish_sampler_indices``.
            num_replicas: number of participating processes; defaults to
                ``dist.get_world_size()``.
            rank: index of this process; defaults to ``dist.get_rank()``.

        Raises:
            RuntimeError: if a default is needed but ``torch.distributed`` is
                not available.
            ValueError: if ``rank`` is outside ``[0, num_replicas)``.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        # Fail early: an out-of-range rank would otherwise only surface later
        # as an opaque assertion failure while iterating.
        if not 0 <= rank < num_replicas:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.batch_size = batch_size

    def __iter__(self) -> Iterable:
        """Iterate over this rank's indices in sortish order.

        The randomness comes from ``sortish_sampler_indices`` (NumPy's global
        RNG); ``self.epoch`` is not consulted here.
        """
        available_indices = self.get_indices_for_rank()
        if not available_indices:
            # Empty dataset: nothing to order.
            return iter([])
        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
        # Map positions within the shard back to dataset indices.
        indices = [available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    def get_indices_for_rank(self) -> List[int]:
        """Return the dataset indices assigned to this rank (unsorted)."""
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        padding = self.total_size - len(indices)
        if padding > 0:
            # Repeat the index list as often as needed so the padding is
            # complete even when it exceeds the dataset size.
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size
        # subsample
        return indices[self.rank : self.total_size : self.num_replicas]

    def __len__(self):
        """Return the number of samples this rank yields per pass."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Record the current epoch number on the sampler."""
        self.epoch = epoch
# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)