split seq2seq script into summarization & translation (#10611)
* split seq2seq script, update docs * needless diff * fix readme * remove test diff * s/summarization/translation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * cr * fix arguments & better mbart/t5 refs * copyright Co-authored-by: Suraj Patil <surajp815@gmail.com> * reword readme Co-authored-by: Suraj Patil <surajp815@gmail.com> * s/summarization/translation * short script names * fix tests * fix isort, include mbart doc * delete old script, update tests * automate source prefix * automate source prefix for translation * s/translation/trans Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * fix script name (short version) * typos Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * exact parameter Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * remove superfluous source_prefix calls in docs * rename scripts & warn for source prefix * black * flake8 Co-authored-by: theo <theo@matussie.re> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -35,7 +35,7 @@ from transformers.trainer_utils import set_seed
|
||||
|
||||
bindir = os.path.abspath(os.path.dirname(__file__))
|
||||
sys.path.append(f"{bindir}/../../seq2seq")
|
||||
from run_seq2seq import main # noqa
|
||||
from run_translation import main # noqa
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -209,7 +209,6 @@ class TestTrainerExt(TestCasePlus):
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
--task translation
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
"""
|
||||
@@ -226,12 +225,12 @@ class TestTrainerExt(TestCasePlus):
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
{self.examples_dir_str}/seq2seq/run_seq2seq.py
|
||||
{self.examples_dir_str}/seq2seq/run_translation.py
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
else:
|
||||
testargs = ["run_seq2seq.py"] + args
|
||||
testargs = ["run_translation.py"] + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user