Supporting Seq2Seq model for question answering task (#13432)
* Add seq2seq example for QnA on the SQuAD dataset.
* Changes from review: fix styling mistakes.
* Add a how-to example to the README; simplify access to the dataset's preprocess function.
* Add tests for the seq2seq QA example.
* Change dataset column name to fix tests.
* Fix mistake in the test command.
* Add missing argument 'ignore_pad_token_for_loss' to DataTrainingArguments.
* Add missing argument 'num_beams' to DataTrainingArguments.
* Fix processing of predicted output token ids so that the tokenizer's decode receives appropriate input. Update assertion conditions in the tests.
This commit is contained in:
@@ -57,6 +57,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
import run_speech_recognition_ctc
|
||||
import run_summarization
|
||||
import run_swag
|
||||
@@ -244,6 +245,40 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_f1"], 30)
|
||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||
|
||||
def test_run_squad_seq2seq(self):
    """Run the seq2seq QA example end to end and check its ROUGE scores.

    Invokes ``run_squad_seq2seq.main()`` with ``sys.argv`` patched to train
    ``t5-small`` for 10 steps on the SQuAD sample fixture, then reads the
    results from the temporary output directory and asserts that every
    reported ROUGE evaluation metric is at least 10.
    """
    # Mirror this module's log output to stdout while the example runs.
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

    tmp_dir = self.get_auto_remove_tmp_dir()
    cli_args = f"""
        run_seq2seq_qa.py
        --model_name_or_path t5-small
        --context_column context
        --question_column question
        --answer_column answers
        --version_2_with_negative
        --train_file tests/fixtures/tests_samples/SQUAD/sample.json
        --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
        --output_dir {tmp_dir}
        --overwrite_output_dir
        --max_steps=10
        --warmup_steps=2
        --do_train
        --do_eval
        --learning_rate=2e-4
        --per_device_train_batch_size=2
        --per_device_eval_batch_size=1
        --predict_with_generate
        """.split()

    # The example script parses sys.argv itself, so swap it for the call.
    with patch.object(sys, "argv", cli_args):
        run_squad_seq2seq.main()
        metrics = get_results(tmp_dir)
        # Same keys, same order, same threshold as individual assertions.
        for rouge_key in (
            "eval_rouge1",
            "eval_rouge2",
            "eval_rougeL",
            "eval_rougeLsum",
        ):
            self.assertGreaterEqual(metrics[rouge_key], 10)
|
||||
|
||||
def test_run_swag(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
Reference in New Issue
Block a user