[s2s] dynamic batch size with --max_tokens_per_batch (#7030)
This commit is contained in:
@@ -68,6 +68,12 @@ class SummarizationModule(BaseTransformer):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
if hparams.sortish_sampler and hparams.gpus > 1:
|
||||
hparams.replace_sampler_ddp = False
|
||||
elif hparams.max_tokens_per_batch is not None:
|
||||
if hparams.gpus > 1:
|
||||
raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
|
||||
if hparams.sortish_sampler:
|
||||
raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")
|
||||
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
@@ -97,6 +103,10 @@ class SummarizationModule(BaseTransformer):
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
|
||||
raise AssertionError("Sortish Sampler does not work for multigpu")
|
||||
if self.hparams.sortish_sampler and self.hparams.max_tokens_per_batch is not None:
|
||||
raise AssertionError("max tokens per batch and sortish sampler are incompatible.")
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
if self.hparams.freeze_encoder:
|
||||
@@ -175,6 +185,10 @@ class SummarizationModule(BaseTransformer):
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
|
||||
logs["bs"] = batch["input_ids"].shape[0]
|
||||
logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
|
||||
logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
|
||||
# TODO(SS): make a wandb summary metric for this
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
@@ -253,20 +267,39 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
    """Build the DataLoader for one dataset split.

    Three batching strategies are supported, checked in this order:

    1. ``--sortish_sampler``: the dataset's sortish sampler picks the example
       order; batches have a fixed ``batch_size``.
    2. ``--max_tokens_per_batch``: the dataset's dynamic sampler yields whole
       batches of indices, so the batch size varies from step to step.
    3. Otherwise: a plain fixed-size loader honouring ``shuffle``.

    The ``"test"`` split always uses strategy 3.

    Args:
        type_path: Name of the split, passed to ``self.get_dataset``.
        batch_size: Examples per batch for the fixed-size strategies; not
            forwarded to the DataLoader when dynamic batching is active.
        shuffle: Whether the plain loader shuffles; ignored when a sampler
            decides the order.

    Returns:
        A DataLoader over the requested split, collated with
        ``dataset.collate_fn``.
    """
    dataset = self.get_dataset(type_path)
    hparams = self.hparams
    use_custom_sampling = type_path != "test"

    if use_custom_sampling and hparams.sortish_sampler:
        sampler = dataset.make_sortish_sampler(batch_size, distributed=hparams.gpus > 1)
        # The sampler fixes the iteration order, so shuffling must stay off.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=False,
            num_workers=self.num_workers,
            sampler=sampler,
        )

    if use_custom_sampling and hparams.max_tokens_per_batch is not None:
        batch_sampler = dataset.make_dynamic_sampler(
            hparams.max_tokens_per_batch, distributed=hparams.gpus > 1
        )
        # DataLoader treats batch_sampler as mutually exclusive with
        # batch_size / shuffle / sampler, so none of those are passed here.
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        shuffle=shuffle,
        num_workers=self.num_workers,
        sampler=None,
    )
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
@@ -313,6 +346,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--max_tokens_per_batch", type=int, default=None)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
|
||||
Reference in New Issue
Block a user