[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -3,7 +3,6 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -14,7 +13,7 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
|
||||
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
|
||||
|
||||
|
||||
try:
|
||||
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if max(scheduler.get_last_lr()) > 0:
|
||||
warnings.warn("All learning rates are 0")
|
||||
self.lr_scheduler = scheduler
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
|
||||
)
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
|
||||
Reference in New Issue
Block a user