New squad example (#8992)
* Add new SQUAD example * Same with a task-specific Trainer * Address review comment. * Small fixes * Initial work for XLNet * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Final clean up and working XLNet script * Test and debug * Final working version * Add new SQUAD example * Same with a task-specific Trainer * Address review comment. * Small fixes * Initial work for XLNet * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Final clean up and working XLNet script * Test and debug * Final working version * Add tick * Update README * Address review comments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -46,7 +46,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_pl_glue
|
||||
import run_squad
|
||||
import run_qa as run_squad
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -213,8 +213,8 @@ class ExamplesTests(TestCasePlus):
|
||||
--do_eval
|
||||
--warmup_steps=2
|
||||
--learning_rate=2e-4
|
||||
--per_gpu_train_batch_size=2
|
||||
--per_gpu_eval_batch_size=2
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=2
|
||||
--num_train_epochs=2
|
||||
""".split()
|
||||
|
||||
@@ -235,26 +235,25 @@ class ExamplesTests(TestCasePlus):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_squad.py
|
||||
--model_type=distilbert
|
||||
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
||||
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
||||
--model_name_or_path bert-base-uncased
|
||||
--version_2_with_negative
|
||||
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--max_steps=10
|
||||
--warmup_steps=2
|
||||
--do_train
|
||||
--do_eval
|
||||
--version_2_with_negative
|
||||
--learning_rate=2e-4
|
||||
--per_gpu_train_batch_size=2
|
||||
--per_gpu_eval_batch_size=1
|
||||
--seed=42
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_squad.main()
|
||||
self.assertGreaterEqual(result["f1"], 25)
|
||||
self.assertGreaterEqual(result["exact"], 21)
|
||||
self.assertGreaterEqual(result["f1"], 30)
|
||||
self.assertGreaterEqual(result["exact"], 30)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_generation(self):
|
||||
|
||||
Reference in New Issue
Block a user