[lightning_base] fix s2s logging, only make train_loader once (#6404)

This commit is contained in:
Sam Shleifer
2020-08-16 22:49:41 -04:00
committed by GitHub
parent 72add6c98f
commit 84c265ffcc
6 changed files with 47 additions and 72 deletions

View File

@@ -3,7 +3,6 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -14,7 +13,7 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
try:
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)