From 61b7ba93f5f4dfcef795e20a9fb11b2d4ee7608e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 31 Aug 2020 13:48:26 -0400 Subject: [PATCH] Marian distill scripts + integration test (#6799) --- .../seq2seq/distil_marian_enro_teacher.sh | 21 ++++ examples/seq2seq/distil_marian_no_teacher.sh | 17 +++ examples/seq2seq/test_bash_script.py | 104 +++++++++++++++--- examples/test_examples.py | 4 +- 4 files changed, 132 insertions(+), 14 deletions(-) create mode 100755 examples/seq2seq/distil_marian_enro_teacher.sh create mode 100755 examples/seq2seq/distil_marian_no_teacher.sh diff --git a/examples/seq2seq/distil_marian_enro_teacher.sh b/examples/seq2seq/distil_marian_enro_teacher.sh new file mode 100755 index 0000000000..575ecf8502 --- /dev/null +++ b/examples/seq2seq/distil_marian_enro_teacher.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +export PYTHONPATH="../":"${PYTHONPATH}" +export WANDB_PROJECT=dmar +# export MAX_LEN=128 +python distillation.py \ + --learning_rate=3e-4 \ + --do_train \ + --do_predict \ + --fp16 \ + --val_check_interval 0.25 \ + --teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \ + --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --student_decoder_layers 3 --student_encoder_layers 6 \ + --freeze_encoder --freeze_embeds \ + --model_name_or_path IGNORED \ + --alpha_hid=3. \ + --train_batch_size=$BS --eval_batch_size=$BS \ + --tokenizer_name Helsinki-NLP/opus-mt-en-ro \ + --warmup_steps 500 --sortish_sampler --logger_name wandb \ + --gpus 1 --fp16_opt_level O1 --task translation \ + "$@" diff --git a/examples/seq2seq/distil_marian_no_teacher.sh b/examples/seq2seq/distil_marian_no_teacher.sh new file mode 100755 index 0000000000..66fdda1d17 --- /dev/null +++ b/examples/seq2seq/distil_marian_no_teacher.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +export PYTHONPATH="../":"${PYTHONPATH}" +export WANDB_PROJECT=dmar +python distillation.py \ + --learning_rate=3e-4 \ + --do_train \ + --do_predict \ + --fp16 --no_teacher \ + --val_check_interval 0.25 \ + --data_dir $ENRO_DIR \ + --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --freeze_encoder --freeze_embeds \ + --train_batch_size=$BS --eval_batch_size=$BS \ + --tokenizer_name $m --model_name_or_path $m \ + --warmup_steps 500 --sortish_sampler --logger_name wandb \ + --gpus 1 --fp16_opt_level=O1 --task translation \ + "$@" diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py index a9cb6e3a09..d352f30008 100644 --- a/examples/seq2seq/test_bash_script.py +++ b/examples/seq2seq/test_bash_script.py @@ -10,9 +10,10 @@ import pytorch_lightning as pl import timeout_decorator import torch -from transformers import BartForConditionalGeneration +from transformers import BartForConditionalGeneration, MarianMTModel from transformers.testing_utils import slow +from .distillation import BartSummarizationDistiller, distill_main from .finetune import SummarizationModule, main from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY from .utils import load_json @@ -20,6 +21,7 @@ from .utils import load_json MODEL_NAME = MBART_TINY # TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1" +MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" @slow @@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY def test_model_download(): """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" BartForConditionalGeneration.from_pretrained(MODEL_NAME) + MarianMTModel.from_pretrained(MARIAN_MODEL) @timeout_decorator.timeout(120) @@ -35,34 +38,30 @@ def test_model_download(): def test_train_mbart_cc25_enro_script(): data_dir = "examples/seq2seq/test_data/wmt_en_ro" env_vars_to_replace = { - "$MAX_LEN": 200, + "--fp16_opt_level=O1": "", + "$MAX_LEN": 128, "$BS": 4, "$GAS": 1, "$ENRO_DIR": data_dir, "facebook/mbart-large-cc25": MODEL_NAME, - # 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly. - # Download is 600MB in previous test. + # Download is 120MB in previous test. "val_check_interval=0.25": "val_check_interval=1.0", } # Clean up bash script bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() - bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "") + bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") for k, v in env_vars_to_replace.items(): bash_script = bash_script.replace(k, str(v)) - output_dir = tempfile.mkdtemp(prefix="output") + output_dir = tempfile.mkdtemp(prefix="output_mbart") - if CUDA_AVAILABLE: - gpus = 1 # torch.cuda.device_count() - else: - gpus = 0 - bash_script = bash_script.replace("--fp16", "") + bash_script = bash_script.replace("--fp16 ", "") testargs = ( ["finetune.py"] + bash_script.split() + [ f"--output_dir={output_dir}", - f"--gpus={gpus}", + "--gpus=1", "--learning_rate=3e-1", "--warmup_steps=0", "--val_check_interval=1.0", @@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script(): metrics = load_json(model.metrics_save_path) first_step_stats = metrics["val"][0] last_step_stats = metrics["val"][-1] - assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check + assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check + + assert last_step_stats["val_avg_gen_time"] >= 0.01 + + assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing + assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. + assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) + + # check lightning ckpt can be loaded and has a reasonable statedict + contents = os.listdir(output_dir) + ckpt_path = [x for x in contents if x.endswith(".ckpt")][0] + full_path = os.path.join(args.output_dir, ckpt_path) + ckpt = torch.load(full_path, map_location="cpu") + expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight" + assert expected_key in ckpt["state_dict"] + assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32 + + # TODO(SS): turn on args.do_predict when PL bug fixed. + if args.do_predict: + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.txt" in contents + # assert len(metrics["val"]) == desired_n_evals + assert len(metrics["test"]) == 1 + + +@timeout_decorator.timeout(600) +@slow +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") +def test_opus_mt_distill_script(): + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + env_vars_to_replace = { + "--fp16_opt_level=O1": "", + "$MAX_LEN": 128, + "$BS": 16, + "$GAS": 1, + "$ENRO_DIR": data_dir, + "$m": "sshleifer/student_marian_en_ro_6_1", + "val_check_interval=0.25": "val_check_interval=1.0", + } + + # Clean up bash script + bash_script = ( + Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip() + ) + bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") + bash_script = bash_script.replace("--fp16 ", " ") + + for k, v in env_vars_to_replace.items(): + bash_script = bash_script.replace(k, str(v)) + output_dir = tempfile.mkdtemp(prefix="marian_output") + bash_script = bash_script.replace("--fp16", "") + epochs = 6 + testargs = ( + ["distillation.py"] + + bash_script.split() + + [ + f"--output_dir={output_dir}", + "--gpus=1", + "--learning_rate=1e-3", + f"--num_train_epochs={epochs}", + "--warmup_steps=10", + "--val_check_interval=1.0", + ] + ) + with patch.object(sys, "argv", testargs): + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() + args.do_predict = False + # assert args.gpus == gpus THIS BREAKS for multigpu + + model = distill_main(args) + + # Check metrics + metrics = load_json(model.metrics_save_path) + first_step_stats = metrics["val"][0] + last_step_stats = metrics["val"][-1] + assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check assert last_step_stats["val_avg_gen_time"] >= 0.01 diff --git a/examples/test_examples.py b/examples/test_examples.py index c6e1d34f89..40bfd2c81a 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus): --max_seq_length=128 """.split() if torch.cuda.is_available(): - testargs += ["--fp16", "--gpus=1"] + testargs += ["--gpus=1"] + if is_cuda_and_apex_avaliable(): + testargs.append("--fp16") with patch.object(sys, "argv", testargs): result = run_pl_glue.main()