Add new token classification example (#8340)
* Add new token classification example * Remove txt file * Add test * With actual testing done * Less warmup is better * Update examples/token-classification/run_ner_new.py Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Address review comments * Fix test * Make Lysandre happy * Last touches and rename * Rename in tests * Address review comments * More run_ner -> run_ner_old Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,13 @@ from transformers.testing_utils import TestCasePlus, torch_device
|
||||
|
||||
SRC_DIRS = [
|
||||
os.path.join(os.path.dirname(__file__), dirname)
|
||||
for dirname in ["text-generation", "text-classification", "language-modeling", "question-answering"]
|
||||
for dirname in [
|
||||
"text-generation",
|
||||
"text-classification",
|
||||
"token-classification",
|
||||
"language-modeling",
|
||||
"question-answering",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
|
||||
@@ -38,6 +44,7 @@ if SRC_DIRS is not None:
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_pl_glue
|
||||
import run_squad
|
||||
|
||||
@@ -185,6 +192,36 @@ class ExamplesTests(TestCasePlus):
|
||||
result = run_mlm.main()
|
||||
self.assertLess(result["perplexity"], 42)
|
||||
|
||||
def test_run_ner(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_ner.py
|
||||
--model_name_or_path bert-base-uncased
|
||||
--train_file tests/fixtures/tests_samples/conll/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/conll/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--do_train
|
||||
--do_eval
|
||||
--warmup_steps=2
|
||||
--learning_rate=2e-4
|
||||
--per_gpu_train_batch_size=2
|
||||
--per_gpu_eval_batch_size=2
|
||||
--num_train_epochs=2
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_ner.main()
|
||||
self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
|
||||
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
||||
self.assertLess(result["eval_loss"], 0.5)
|
||||
|
||||
def test_run_squad(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
Reference in New Issue
Block a user