split seq2seq script into summarization & translation (#10611)
* split seq2seq script, update docs * needless diff * fix readme * remove test diff * s/summarization/translation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * cr * fix arguments & better mbart/t5 refs * copyright Co-authored-by: Suraj Patil <surajp815@gmail.com> * reword readme Co-authored-by: Suraj Patil <surajp815@gmail.com> * s/summarization/translation * short script names * fix tests * fix isort, include mbart doc * delete old script, update tests * automate source prefix * automate source prefix for translation * s/translation/trans Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * fix script name (short version) * typos Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * exact parameter Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * remove superfluous source_prefix calls in docs * rename scripts & warn for source prefix * black * flake8 Co-authored-by: theo <theo@matussie.re> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -49,8 +49,9 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_seq2seq
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -277,15 +278,14 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(len(result[0]), 10)
|
||||
|
||||
@slow
|
||||
def test_run_seq2seq_summarization(self):
|
||||
def test_run_summarization(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_seq2seq.py
|
||||
run_summarization.py
|
||||
--model_name_or_path t5-small
|
||||
--task summarization
|
||||
--train_file tests/fixtures/tests_samples/xsum/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/xsum/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
@@ -301,7 +301,7 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_seq2seq.main()
|
||||
run_summarization.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_rouge1"], 10)
|
||||
self.assertGreaterEqual(result["eval_rouge2"], 2)
|
||||
@@ -309,15 +309,16 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
|
||||
|
||||
@slow
|
||||
def test_run_seq2seq_translation(self):
|
||||
def test_run_translation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_seq2seq.py
|
||||
run_translation.py
|
||||
--model_name_or_path sshleifer/student_marian_en_ro_6_1
|
||||
--task translation_en_to_ro
|
||||
--source_lang en
|
||||
--target_lang ro
|
||||
--train_file tests/fixtures/tests_samples/wmt16/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/wmt16/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
@@ -335,6 +336,6 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_seq2seq.main()
|
||||
run_translation.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
Reference in New Issue
Block a user