Marian distill scripts + integration test (#6799)

This commit is contained in:
Sam Shleifer
2020-08-31 13:48:26 -04:00
committed by GitHub
parent 02d09c8fcc
commit 61b7ba93f5
4 changed files with 132 additions and 14 deletions

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
# Launch distillation of a Marian en-ro student (6 encoder / 3 decoder layers)
# from the Helsinki-NLP/opus-mt-en-ro teacher, logging to the "dmar" wandb project.
#
# Required environment variables:
#   ENRO_DIR  value passed to --data_dir
#   MAX_LEN   value passed to the source/target/val/test max-length flags (e.g. 128)
#   BS        value passed to --train_batch_size and --eval_batch_size
# Any extra command-line arguments are forwarded unchanged via "$@".
#
# --model_name_or_path receives the placeholder IGNORED
# (presumably unused when --teacher is given -- TODO confirm).

# Fail fast with a clear message when a required variable is unset or empty.
# Without these guards the unquoted expansions below vanish, the flags are
# passed without a value, and the error only surfaces during argument parsing.
: "${ENRO_DIR:?set ENRO_DIR to the data directory}"
: "${MAX_LEN:?set MAX_LEN, e.g. export MAX_LEN=128}"
: "${BS:?set BS to the train/eval batch size}"

export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
# NOTE(review): the command text is intentionally left as-is; tests in this repo
# rewrite sibling launch scripts by plain-text substitution of the $VARS.
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.25 \
--teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--student_decoder_layers 3 --student_encoder_layers 6 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level O1 --task translation \
"$@"

View File

@@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Launch training of the model named by $m with --no_teacher, logging to the
# "dmar" wandb project.
#
# Environment variables read by this script:
#   ENRO_DIR  value passed to --data_dir
#   MAX_LEN   value passed to the source/target/val/test max-length flags
#   BS        value passed to --train_batch_size and --eval_batch_size
#   m         value passed to both --tokenizer_name and --model_name_or_path
# Any extra command-line arguments are appended via "$@".
#
# NOTE(review): a test appears to read this file as plain text, split it on the
# name of the Python entry point, and substitute the literal tokens above.
# Keep those tokens spelled exactly as they are, and do not mention the
# entry-point file name anywhere except in the command below.
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --no_teacher \
--val_check_interval 0.25 \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--freeze_encoder --freeze_embeds \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level=O1 --task translation \
"$@"

View File

@@ -10,9 +10,10 @@ import pytorch_lightning as pl
import timeout_decorator import timeout_decorator
import torch import torch
from transformers import BartForConditionalGeneration from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import slow from transformers.testing_utils import slow
from .distillation import BartSummarizationDistiller, distill_main
from .finetune import SummarizationModule, main from .finetune import SummarizationModule, main
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from .utils import load_json from .utils import load_json
@@ -20,6 +21,7 @@ from .utils import load_json
MODEL_NAME = MBART_TINY MODEL_NAME = MBART_TINY
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1" # TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
@slow @slow
@@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY
def test_model_download(): def test_model_download():
"""This warms up the cache so that we can time the next test without including download time, which varies between machines.""" """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
BartForConditionalGeneration.from_pretrained(MODEL_NAME) BartForConditionalGeneration.from_pretrained(MODEL_NAME)
MarianMTModel.from_pretrained(MARIAN_MODEL)
@timeout_decorator.timeout(120) @timeout_decorator.timeout(120)
@@ -35,34 +38,30 @@ def test_model_download():
def test_train_mbart_cc25_enro_script(): def test_train_mbart_cc25_enro_script():
data_dir = "examples/seq2seq/test_data/wmt_en_ro" data_dir = "examples/seq2seq/test_data/wmt_en_ro"
env_vars_to_replace = { env_vars_to_replace = {
"$MAX_LEN": 200, "--fp16_opt_level=O1": "",
"$MAX_LEN": 128,
"$BS": 4, "$BS": 4,
"$GAS": 1, "$GAS": 1,
"$ENRO_DIR": data_dir, "$ENRO_DIR": data_dir,
"facebook/mbart-large-cc25": MODEL_NAME, "facebook/mbart-large-cc25": MODEL_NAME,
# 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly. # Download is 120MB in previous test.
# Download is 600MB in previous test.
"val_check_interval=0.25": "val_check_interval=1.0", "val_check_interval=0.25": "val_check_interval=1.0",
} }
# Clean up bash script # Clean up bash script
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "") bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
for k, v in env_vars_to_replace.items(): for k, v in env_vars_to_replace.items():
bash_script = bash_script.replace(k, str(v)) bash_script = bash_script.replace(k, str(v))
output_dir = tempfile.mkdtemp(prefix="output") output_dir = tempfile.mkdtemp(prefix="output_mbart")
if CUDA_AVAILABLE:
gpus = 1 # torch.cuda.device_count()
else:
gpus = 0
bash_script = bash_script.replace("--fp16 ", "") bash_script = bash_script.replace("--fp16 ", "")
testargs = ( testargs = (
["finetune.py"] ["finetune.py"]
+ bash_script.split() + bash_script.split()
+ [ + [
f"--output_dir={output_dir}", f"--output_dir={output_dir}",
f"--gpus={gpus}", "--gpus=1",
"--learning_rate=3e-1", "--learning_rate=3e-1",
"--warmup_steps=0", "--warmup_steps=0",
"--val_check_interval=1.0", "--val_check_interval=1.0",
@@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script():
metrics = load_json(model.metrics_save_path) metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0] first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1] last_step_stats = metrics["val"][-1]
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
assert last_step_stats["val_avg_gen_time"] >= 0.01
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
# check lightning ckpt can be loaded and has a reasonable statedict
contents = os.listdir(output_dir)
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
full_path = os.path.join(args.output_dir, ckpt_path)
ckpt = torch.load(full_path, map_location="cpu")
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
assert expected_key in ckpt["state_dict"]
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
# TODO(SS): turn on args.do_predict when PL bug fixed.
if args.do_predict:
contents = {os.path.basename(p) for p in contents}
assert "test_generations.txt" in contents
assert "test_results.txt" in contents
# assert len(metrics["val"]) == desired_n_evals
assert len(metrics["test"]) == 1
@timeout_decorator.timeout(600)
@slow
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
def test_opus_mt_distill_script():
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
env_vars_to_replace = {
"--fp16_opt_level=O1": "",
"$MAX_LEN": 128,
"$BS": 16,
"$GAS": 1,
"$ENRO_DIR": data_dir,
"$m": "sshleifer/student_marian_en_ro_6_1",
"val_check_interval=0.25": "val_check_interval=1.0",
}
# Clean up bash script
bash_script = (
Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
)
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
bash_script = bash_script.replace("--fp16 ", " ")
for k, v in env_vars_to_replace.items():
bash_script = bash_script.replace(k, str(v))
output_dir = tempfile.mkdtemp(prefix="marian_output")
bash_script = bash_script.replace("--fp16", "")
epochs = 6
testargs = (
["distillation.py"]
+ bash_script.split()
+ [
f"--output_dir={output_dir}",
"--gpus=1",
"--learning_rate=1e-3",
f"--num_train_epochs={epochs}",
"--warmup_steps=10",
"--val_check_interval=1.0",
]
)
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
args.do_predict = False
# assert args.gpus == gpus THIS BREAKS for multigpu
model = distill_main(args)
# Check metrics
metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1]
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
assert last_step_stats["val_avg_gen_time"] >= 0.01 assert last_step_stats["val_avg_gen_time"] >= 0.01

View File

@@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus):
--max_seq_length=128 --max_seq_length=128
""".split() """.split()
if torch.cuda.is_available(): if torch.cuda.is_available():
testargs += ["--fp16", "--gpus=1"] testargs += ["--gpus=1"]
if is_cuda_and_apex_avaliable():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_pl_glue.main() result = run_pl_glue.main()