[wip/s2s] DistributedSortishSampler (#7056)
This commit is contained in:
@@ -3,7 +3,6 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
- import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
|
||||
default_val_metric = "rouge2"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
+ if hparams.sortish_sampler and hparams.gpus > 1:
|
||||
+ hparams.replace_sampler_ddp = False
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
- if self.hparams.sortish_sampler and self.hparams.gpus > 1:
|
||||
- self.hparams.sortish_sampler = False
|
||||
- warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
- assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
|
||||
- sampler = dataset.make_sortish_sampler(batch_size)
|
||||
+ sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user