[wip/s2s] DistributedSortishSampler (#7056)
This commit is contained in:
@@ -3,7 +3,6 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
|
|||||||
default_val_metric = "rouge2"
|
default_val_metric = "rouge2"
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
|
if hparams.sortish_sampler and hparams.gpus > 1:
|
||||||
|
hparams.replace_sampler_ddp = False
|
||||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||||
use_task_specific_params(self.model, "summarization")
|
use_task_specific_params(self.model, "summarization")
|
||||||
save_git_info(self.hparams.output_dir)
|
save_git_info(self.hparams.output_dir)
|
||||||
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
"val": self.hparams.val_max_target_length,
|
"val": self.hparams.val_max_target_length,
|
||||||
"test": self.hparams.test_max_target_length,
|
"test": self.hparams.test_max_target_length,
|
||||||
}
|
}
|
||||||
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
|
|
||||||
self.hparams.sortish_sampler = False
|
|
||||||
warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
|
|
||||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||||
|
|
||||||
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
dataset = self.get_dataset(type_path)
|
dataset = self.get_dataset(type_path)
|
||||||
sampler = None
|
sampler = None
|
||||||
if self.hparams.sortish_sampler and type_path == "train":
|
if self.hparams.sortish_sampler and type_path == "train":
|
||||||
assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
|
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||||
sampler = dataset.make_sortish_sampler(batch_size)
|
|
||||||
shuffle = False
|
shuffle = False
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
|
|||||||
@@ -149,9 +149,9 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
no_teacher=True,
|
no_teacher=True,
|
||||||
freeze_encoder=True,
|
freeze_encoder=True,
|
||||||
gpus=2,
|
gpus=2,
|
||||||
sortish_sampler=False,
|
sortish_sampler=True,
|
||||||
)
|
)
|
||||||
self._test_distiller_cli(updates)
|
self._test_distiller_cli(updates, check_contents=False)
|
||||||
|
|
||||||
def test_distill_no_teacher(self):
|
def test_distill_no_teacher(self):
|
||||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import linecache
|
import linecache
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
|
|||||||
import git
|
import git
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from rouge_score import rouge_scorer, scoring
|
from rouge_score import rouge_scorer, scoring
|
||||||
from sacrebleu import corpus_bleu
|
from sacrebleu import corpus_bleu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
|
|||||||
def get_char_lens(data_file):
|
def get_char_lens(data_file):
|
||||||
return [len(x) for x in Path(data_file).open().readlines()]
|
return [len(x) for x in Path(data_file).open().readlines()]
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size):
|
def make_sortish_sampler(self, batch_size, distributed=False):
|
||||||
return SortishSampler(self.src_lens, batch_size)
|
if distributed:
|
||||||
|
return DistributedSortishSampler(self, batch_size)
|
||||||
|
else:
|
||||||
|
return SortishSampler(self.src_lens, batch_size)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
raise NotImplementedError("You must implement this")
|
raise NotImplementedError("You must implement this")
|
||||||
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
|
|||||||
def __init__(self, data, batch_size):
|
def __init__(self, data, batch_size):
|
||||||
self.data, self.bs = data, batch_size
|
self.data, self.bs = data, batch_size
|
||||||
|
|
||||||
def key(self, i):
|
|
||||||
return self.data[i]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
idxs = np.random.permutation(len(self.data))
|
return iter(sortish_sampler_indices(self.data, self.bs))
|
||||||
sz = self.bs * 50
|
|
||||||
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
|
|
||||||
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
|
def sortish_sampler_indices(data: List, bs: int) -> np.array:
|
||||||
sz = self.bs
|
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||||
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
|
|
||||||
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
|
def key_fn(i):
|
||||||
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
|
return data[i]
|
||||||
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
|
|
||||||
sort_idx = np.concatenate((ck_idx[0], sort_idx))
|
idxs = np.random.permutation(len(data))
|
||||||
return iter(sort_idx)
|
sz = bs * 50
|
||||||
|
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
|
||||||
|
sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
|
||||||
|
sz = bs
|
||||||
|
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
|
||||||
|
max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
|
||||||
|
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
|
||||||
|
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
|
||||||
|
sort_idx = np.concatenate((ck_idx[0], sort_idx))
|
||||||
|
return sort_idx
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedSortishSampler(Sampler):
|
||||||
|
"""Copied from torch DistributedSampler"""
|
||||||
|
|
||||||
|
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
|
||||||
|
if num_replicas is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
num_replicas = dist.get_world_size()
|
||||||
|
if rank is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
rank = dist.get_rank()
|
||||||
|
self.dataset = dataset
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
self.epoch = 0
|
||||||
|
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterable:
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(self.epoch)
|
||||||
|
available_indices = self.get_indices_for_rank() # indices[self.rank: self.total_size: self.num_replicas]
|
||||||
|
|
||||||
|
sortish_data = [self.dataset.src_lens[i] for i in available_indices]
|
||||||
|
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
|
||||||
|
indices = [available_indices[i] for i in sortish_indices]
|
||||||
|
assert len(indices) == self.num_samples
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def get_indices_for_rank(self) -> np.array:
|
||||||
|
indices = list(range(len(self.dataset)))
|
||||||
|
# add extra samples to make it evenly divisible
|
||||||
|
indices += indices[: (self.total_size - len(indices))]
|
||||||
|
assert len(indices) == self.total_size
|
||||||
|
# subsample
|
||||||
|
available_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||||
|
return available_indices
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
self.epoch = epoch
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|||||||
Reference in New Issue
Block a user