[lightning_base] fix s2s logging, only make train_loader once (#6404)

This commit is contained in:
Sam Shleifer
2020-08-16 22:49:41 -04:00
committed by GitHub
parent 72add6c98f
commit 84c265ffcc
6 changed files with 47 additions and 72 deletions

View File

@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
)
def validation_step(self, batch, batch_nb):
"Compute validation"
"""Compute validation""" ""
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if self.config.model_type != "distilbert":
inputs["token_type_ids"] = (
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
help="The number of GPUs allocated for this, it is by default 0 meaning none",
)
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)

View File

@@ -4,6 +4,7 @@ import unittest
from unittest.mock import patch
import run_ner
from transformers.testing_utils import slow
logging.basicConfig(level=logging.INFO)
@@ -12,6 +13,7 @@ logger = logging.getLogger()
class ExamplesTests(unittest.TestCase):
@slow
def test_run_ner(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
with patch.object(sys, "argv", ["run.py"] + testargs):
result = run_ner.main()
self.assertLess(result["eval_loss"], 1.5)
def test_run_ner_pl(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = """
--model_name distilbert-base-german-cased
--output_dir ./tests/fixtures/tests_samples/temp_dir
--overwrite_output_dir
--data_dir ./tests/fixtures/tests_samples/GermEval
--labels ./tests/fixtures/tests_samples/GermEval/labels.txt
--max_seq_length 128
--num_train_epochs 6
--logging_steps 1
--do_train
--do_eval
""".split()
with patch.object(sys, "argv", ["run.py"] + testargs):
result = run_ner.main()
self.assertLess(result["eval_loss"], 1.5)