[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
|
||||
)
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
"Compute validation"
|
||||
|
||||
"""Compute validation""" ""
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if self.config.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
|
||||
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_ner
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -12,6 +13,7 @@ logger = logging.getLogger()
|
||||
|
||||
|
||||
class ExamplesTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_run_ner(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
|
||||
with patch.object(sys, "argv", ["run.py"] + testargs):
|
||||
result = run_ner.main()
|
||||
self.assertLess(result["eval_loss"], 1.5)
|
||||
|
||||
def test_run_ner_pl(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = """
|
||||
--model_name distilbert-base-german-cased
|
||||
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
||||
--overwrite_output_dir
|
||||
--data_dir ./tests/fixtures/tests_samples/GermEval
|
||||
--labels ./tests/fixtures/tests_samples/GermEval/labels.txt
|
||||
--max_seq_length 128
|
||||
--num_train_epochs 6
|
||||
--logging_steps 1
|
||||
--do_train
|
||||
--do_eval
|
||||
""".split()
|
||||
with patch.object(sys, "argv", ["run.py"] + testargs):
|
||||
result = run_ner.main()
|
||||
self.assertLess(result["eval_loss"], 1.5)
|
||||
|
||||
Reference in New Issue
Block a user