New run_seq2seq script (#9605)

* New run_seq2seq script

* Add tests

* Mark as slow

* Update examples/seq2seq/run_seq2seq.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/data/data_collator.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/data/data_collator.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Address review comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-01-19 15:22:17 -05:00
committed by GitHub
parent fa876aee2a
commit e4c06ed664
7 changed files with 650 additions and 1 deletions

View File

@@ -23,7 +23,7 @@ from unittest.mock import patch
import torch
from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, torch_device
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow, torch_device
SRC_DIRS = [
@@ -35,6 +35,7 @@ SRC_DIRS = [
"language-modeling",
"multiple-choice",
"question-answering",
"seq2seq",
]
]
sys.path.extend(SRC_DIRS)
@@ -47,6 +48,7 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_seq2seq
import run_swag
@@ -259,3 +261,67 @@ class ExamplesTests(TestCasePlus):
with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main()
self.assertGreaterEqual(len(result[0]), 10)
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_run_seq2seq_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
--model_name_or_path t5-small
--task summarization
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=50
--warmup_steps=8
--do_train
--do_eval
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
result = run_seq2seq.main()
self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 2)
self.assertGreaterEqual(result["eval_rougeL"], 7)
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_run_seq2seq_translation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--task translation_en_to_ro
--train_file tests/fixtures/tests_samples/wmt16/sample.json
--validation_file tests/fixtures/tests_samples/wmt16/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=50
--warmup_steps=8
--do_train
--do_eval
--learning_rate=3e-3
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
--source_lang en_XX
--target_lang ro_RO
""".split()
with patch.object(sys, "argv", testargs):
result = run_seq2seq.main()
self.assertGreaterEqual(result["eval_bleu"], 30)