[lightning_base] fix s2s logging, only make train_loader once (#6404)

This commit is contained in:
Sam Shleifer
2020-08-16 22:49:41 -04:00
committed by GitHub
parent 72add6c98f
commit 84c265ffcc
6 changed files with 47 additions and 72 deletions

View File

@@ -10,14 +10,7 @@ from torch import nn
from torch.nn import functional as F
from lightning_base import generic_train
from transformers import (
AdamW,
BartConfig,
BartForConditionalGeneration,
MBartTokenizer,
T5Config,
T5ForConditionalGeneration,
)
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
try:
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
)
return loss_ce, s_logits_slct, t_logits_slct
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
@staticmethod
def add_model_specific_args(parser, root_dir):
SummarizationModule.add_model_specific_args(parser, root_dir)

View File

@@ -3,7 +3,6 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -14,7 +13,7 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
try:
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)