Add new token classification example (#8340)

* Add new token classification example

* Remove txt file

* Add test

* With actual testing done

* Less warmup is better

* Update examples/token-classification/run_ner_new.py

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Address review comments

* Fix test

* Make Lysandre happy

* Last touches and rename

* Rename in tests

* Address review comments

* More run_ner -> run_ner_old

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-11-09 11:39:55 -05:00
committed by GitHub
parent c7cb1aa26c
commit 908a28894c
21 changed files with 652 additions and 185 deletions

View File

@@ -28,7 +28,13 @@ from transformers.testing_utils import TestCasePlus, torch_device
# Absolute-ish paths of the example sub-directories that sit next to this file.
# A single `for` clause is used on purpose: chaining a second `for dirname in`
# clause would repeat every path once per item of the first list.
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "question-answering",
    ]
]
# Put those directories on the import path so the example scripts they contain
# can be imported as plain top-level modules.
sys.path.extend(SRC_DIRS)
@@ -38,6 +44,7 @@ if SRC_DIRS is not None:
import run_generation
import run_glue
import run_mlm
import run_ner
import run_pl_glue
import run_squad
@@ -185,6 +192,36 @@ class ExamplesTests(TestCasePlus):
result = run_mlm.main()
self.assertLess(result["perplexity"], 42)
def test_run_ner(self):
    """Run ``run_ner.main()`` end to end on the CoNLL sample and check its metrics.

    The same sample file is passed as both the training and the validation
    file. The test requires ``eval_accuracy_score`` and ``eval_precision`` to be
    at least 0.75 and ``eval_loss`` to be below 0.5.
    """
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    # Detach the handler once the test is over so handlers do not accumulate
    # on `logger` from one test to the next.
    self.addCleanup(logger.removeHandler, stream_handler)

    tmp_dir = self.get_auto_remove_tmp_dir()
    # `.split()` turns the block into an argv list, so the layout inside the
    # string is irrelevant. The batch-size flags use the `per_device_*`
    # spelling, which is the preferred form of the older `per_gpu_*` flags.
    testargs = f"""
        run_ner.py
        --model_name_or_path bert-base-uncased
        --train_file tests/fixtures/tests_samples/conll/sample.json
        --validation_file tests/fixtures/tests_samples/conll/sample.json
        --output_dir {tmp_dir}
        --overwrite_output_dir
        --do_train
        --do_eval
        --warmup_steps=2
        --learning_rate=2e-4
        --per_device_train_batch_size=2
        --per_device_eval_batch_size=2
        --num_train_epochs=2
        """.split()

    if torch_device != "cuda":
        testargs.append("--no_cuda")

    # Only the call to `main()` needs the patched `sys.argv`.
    with patch.object(sys, "argv", testargs):
        result = run_ner.main()

    self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
    self.assertGreaterEqual(result["eval_precision"], 0.75)
    self.assertLess(result["eval_loss"], 0.5)
def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)