New run_seq2seq script (#9605)
* New run_seq2seq script * Add tests * Mark as slow * Update examples/seq2seq/run_seq2seq.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/data/data_collator.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update src/transformers/data/data_collator.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Address review comments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, torch_device
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow, torch_device
|
||||
|
||||
|
||||
SRC_DIRS = [
|
||||
@@ -35,6 +35,7 @@ SRC_DIRS = [
|
||||
"language-modeling",
|
||||
"multiple-choice",
|
||||
"question-answering",
|
||||
"seq2seq",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -47,6 +48,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_seq2seq
|
||||
import run_swag
|
||||
|
||||
|
||||
@@ -259,3 +261,67 @@ class ExamplesTests(TestCasePlus):
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result[0]), 10)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_run_seq2seq_summarization(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_seq2seq.py
|
||||
--model_name_or_path t5-small
|
||||
--task summarization
|
||||
--train_file tests/fixtures/tests_samples/xsum/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/xsum/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--max_steps=50
|
||||
--warmup_steps=8
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate=2e-4
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--predict_with_generate
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_seq2seq.main()
|
||||
|
||||
self.assertGreaterEqual(result["eval_rouge1"], 10)
|
||||
self.assertGreaterEqual(result["eval_rouge2"], 2)
|
||||
self.assertGreaterEqual(result["eval_rougeL"], 7)
|
||||
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_run_seq2seq_translation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_seq2seq.py
|
||||
--model_name_or_path sshleifer/student_marian_en_ro_6_1
|
||||
--task translation_en_to_ro
|
||||
--train_file tests/fixtures/tests_samples/wmt16/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/wmt16/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--max_steps=50
|
||||
--warmup_steps=8
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate=3e-3
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--predict_with_generate
|
||||
--source_lang en_XX
|
||||
--target_lang ro_RO
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_seq2seq.main()
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
Reference in New Issue
Block a user