split seq2seq script into summarization & translation (#10611)

* split seq2seq script, update docs

* needless diff

* fix readme

* remove test diff

* s/summarization/translation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cr

* fix arguments & better mbart/t5 refs

* copyright

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* reword readme

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* s/summarization/translation

* short script names

* fix tests

* fix isort, include mbart doc

* delete old script, update tests

* automate source prefix

* automate source prefix for translation

* s/translation/trans

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* fix script name (short version)

* typos

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* exact parameter

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* remove superfluous source_prefix calls in docs

* rename scripts & warn for source prefix

* black

* flake8

Co-authored-by: theo <theo@matussie.re>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Théo Matussière
2021-03-15 14:11:42 +01:00
committed by GitHub
parent 505494a86f
commit 6f840990a7
9 changed files with 653 additions and 168 deletions

View File

@@ -49,8 +49,9 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_seq2seq
import run_summarization
import run_swag
import run_translation
logging.basicConfig(level=logging.DEBUG)
@@ -277,15 +278,14 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(len(result[0]), 10)
@slow
def test_run_seq2seq_summarization(self):
def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
run_summarization.py
--model_name_or_path t5-small
--task summarization
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
@@ -301,7 +301,7 @@ class ExamplesTests(TestCasePlus):
""".split()
with patch.object(sys, "argv", testargs):
run_seq2seq.main()
run_summarization.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 2)
@@ -309,15 +309,16 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
@slow
def test_run_seq2seq_translation(self):
def test_run_translation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
run_translation.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--task translation_en_to_ro
--source_lang en
--target_lang ro
--train_file tests/fixtures/tests_samples/wmt16/sample.json
--validation_file tests/fixtures/tests_samples/wmt16/sample.json
--output_dir {tmp_dir}
@@ -335,6 +336,6 @@ class ExamplesTests(TestCasePlus):
""".split()
with patch.object(sys, "argv", testargs):
run_seq2seq.main()
run_translation.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)