[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
return self.validation_end(outputs)
|
return self.validation_end(outputs)
|
||||||
|
|
||||||
def setup(self, step):
|
@property
|
||||||
train_batch_size = self.hparams.train_batch_size
|
def total_steps(self) -> int:
|
||||||
dataloader = self.get_dataloader("train", train_batch_size)
|
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
|
||||||
self.train_loader = dataloader
|
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
|
||||||
self.total_steps = (
|
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||||
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
|
dataset_size = len(self.train_loader.dataset)
|
||||||
// self.hparams.accumulate_grad_batches
|
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||||
* float(self.hparams.max_epochs)
|
|
||||||
)
|
def setup(self, mode):
|
||||||
|
if mode == "fit":
|
||||||
|
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||||
|
|
||||||
|
def get_dataloader(self, type_path, batch_size, shuffle=False):
|
||||||
|
raise NotImplementedError("You must implement this for your task")
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return self.train_loader
|
return self.train_loader
|
||||||
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
|
|||||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generic_train(
|
def generic_train(
|
||||||
|
|||||||
@@ -10,14 +10,7 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from lightning_base import generic_train
|
from lightning_base import generic_train
|
||||||
from transformers import (
|
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||||
AdamW,
|
|
||||||
BartConfig,
|
|
||||||
BartForConditionalGeneration,
|
|
||||||
MBartTokenizer,
|
|
||||||
T5Config,
|
|
||||||
T5ForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
)
|
)
|
||||||
return loss_ce, s_logits_slct, t_logits_slct
|
return loss_ce, s_logits_slct, t_logits_slct
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
|
||||||
model = self.model
|
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": self.hparams.weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
|
||||||
self.opt = optimizer
|
|
||||||
return [optimizer]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
@@ -14,7 +13,7 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
|
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
def train_dataloader(self) -> DataLoader:
|
def train_dataloader(self) -> DataLoader:
|
||||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||||
t_total = (
|
|
||||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
|
||||||
// self.hparams.accumulate_grad_batches
|
|
||||||
* float(self.hparams.max_epochs)
|
|
||||||
)
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
|
||||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
|
||||||
)
|
|
||||||
if max(scheduler.get_last_lr()) > 0:
|
|
||||||
warnings.warn("All learning rates are 0")
|
|
||||||
self.lr_scheduler = scheduler
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
def val_dataloader(self) -> DataLoader:
|
def val_dataloader(self) -> DataLoader:
|
||||||
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--data_dir",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
|
|
||||||
)
|
|
||||||
parser.add_argument("--freeze_encoder", action="store_true")
|
parser.add_argument("--freeze_encoder", action="store_true")
|
||||||
parser.add_argument("--freeze_embeds", action="store_true")
|
parser.add_argument("--freeze_embeds", action="store_true")
|
||||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
|
|||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
|
|
||||||
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
|
def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||||
"Load datasets. Called after prepare data."
|
"Load datasets. Called after prepare data."
|
||||||
|
|
||||||
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
||||||
@@ -161,13 +161,6 @@ class GLUETransformer(BaseTransformer):
|
|||||||
type=int,
|
type=int,
|
||||||
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--data_dir",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
|
|||||||
@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def validation_step(self, batch, batch_nb):
|
def validation_step(self, batch, batch_nb):
|
||||||
"Compute validation"
|
"""Compute validation""" ""
|
||||||
|
|
||||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
if self.config.model_type != "distilbert":
|
if self.config.model_type != "distilbert":
|
||||||
inputs["token_type_ids"] = (
|
inputs["token_type_ids"] = (
|
||||||
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
|
|||||||
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--data_dir",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import run_ner
|
import run_ner
|
||||||
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -12,6 +13,7 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
|
|
||||||
class ExamplesTests(unittest.TestCase):
|
class ExamplesTests(unittest.TestCase):
|
||||||
|
@slow
|
||||||
def test_run_ner(self):
|
def test_run_ner(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
with patch.object(sys, "argv", ["run.py"] + testargs):
|
with patch.object(sys, "argv", ["run.py"] + testargs):
|
||||||
result = run_ner.main()
|
result = run_ner.main()
|
||||||
self.assertLess(result["eval_loss"], 1.5)
|
self.assertLess(result["eval_loss"], 1.5)
|
||||||
|
|
||||||
|
def test_run_ner_pl(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
testargs = """
|
||||||
|
--model_name distilbert-base-german-cased
|
||||||
|
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
||||||
|
--overwrite_output_dir
|
||||||
|
--data_dir ./tests/fixtures/tests_samples/GermEval
|
||||||
|
--labels ./tests/fixtures/tests_samples/GermEval/labels.txt
|
||||||
|
--max_seq_length 128
|
||||||
|
--num_train_epochs 6
|
||||||
|
--logging_steps 1
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
""".split()
|
||||||
|
with patch.object(sys, "argv", ["run.py"] + testargs):
|
||||||
|
result = run_ner.main()
|
||||||
|
self.assertLess(result["eval_loss"], 1.5)
|
||||||
|
|||||||
Reference in New Issue
Block a user