[wip/s2s] DistributedSortishSampler (#7056)

This commit is contained in:
Sam Shleifer
2020-09-10 15:23:44 -04:00
committed by GitHub
parent 514486739c
commit 77950c485a
3 changed files with 79 additions and 24 deletions

View File

@@ -3,7 +3,6 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
default_val_metric = "rouge2"
def __init__(self, hparams, **kwargs):
if hparams.sortish_sampler and hparams.gpus > 1:
hparams.replace_sampler_ddp = False
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
use_task_specific_params(self.model, "summarization")
save_git_info(self.hparams.output_dir)
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
self.hparams.sortish_sampler = False
warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
sampler = dataset.make_sortish_sampler(batch_size)
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
shuffle = False
dataloader = DataLoader(