Marian distill scripts + integration test (#6799)
This commit is contained in:
21
examples/seq2seq/distil_marian_enro_teacher.sh
Executable file
21
examples/seq2seq/distil_marian_enro_teacher.sh
Executable file
@@ -0,0 +1,21 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
export WANDB_PROJECT=dmar
|
||||||
|
# export MAX_LEN=128
|
||||||
|
python distillation.py \
|
||||||
|
--learning_rate=3e-4 \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--fp16 \
|
||||||
|
--val_check_interval 0.25 \
|
||||||
|
--teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
|
||||||
|
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||||
|
--student_decoder_layers 3 --student_encoder_layers 6 \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--model_name_or_path IGNORED \
|
||||||
|
--alpha_hid=3. \
|
||||||
|
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||||
|
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
|
||||||
|
--warmup_steps 500 --sortish_sampler --logger_name wandb \
|
||||||
|
--gpus 1 --fp16_opt_level O1 --task translation \
|
||||||
|
"$@"
|
||||||
17
examples/seq2seq/distil_marian_no_teacher.sh
Executable file
17
examples/seq2seq/distil_marian_no_teacher.sh
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
export WANDB_PROJECT=dmar
|
||||||
|
python distillation.py \
|
||||||
|
--learning_rate=3e-4 \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--fp16 --no_teacher \
|
||||||
|
--val_check_interval 0.25 \
|
||||||
|
--data_dir $ENRO_DIR \
|
||||||
|
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||||
|
--tokenizer_name $m --model_name_or_path $m \
|
||||||
|
--warmup_steps 500 --sortish_sampler --logger_name wandb \
|
||||||
|
--gpus 1 --fp16_opt_level=O1 --task translation \
|
||||||
|
"$@"
|
||||||
@@ -10,9 +10,10 @@ import pytorch_lightning as pl
|
|||||||
import timeout_decorator
|
import timeout_decorator
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BartForConditionalGeneration
|
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
|
from .distillation import BartSummarizationDistiller, distill_main
|
||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
||||||
from .utils import load_json
|
from .utils import load_json
|
||||||
@@ -20,6 +21,7 @@ from .utils import load_json
|
|||||||
|
|
||||||
MODEL_NAME = MBART_TINY
|
MODEL_NAME = MBART_TINY
|
||||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||||
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY
|
|||||||
def test_model_download():
|
def test_model_download():
|
||||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||||
|
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||||
|
|
||||||
|
|
||||||
@timeout_decorator.timeout(120)
|
@timeout_decorator.timeout(120)
|
||||||
@@ -35,34 +38,30 @@ def test_model_download():
|
|||||||
def test_train_mbart_cc25_enro_script():
|
def test_train_mbart_cc25_enro_script():
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
env_vars_to_replace = {
|
env_vars_to_replace = {
|
||||||
"$MAX_LEN": 200,
|
"--fp16_opt_level=O1": "",
|
||||||
|
"$MAX_LEN": 128,
|
||||||
"$BS": 4,
|
"$BS": 4,
|
||||||
"$GAS": 1,
|
"$GAS": 1,
|
||||||
"$ENRO_DIR": data_dir,
|
"$ENRO_DIR": data_dir,
|
||||||
"facebook/mbart-large-cc25": MODEL_NAME,
|
"facebook/mbart-large-cc25": MODEL_NAME,
|
||||||
# 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly.
|
# Download is 120MB in previous test.
|
||||||
# Download is 600MB in previous test.
|
|
||||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Clean up bash script
|
# Clean up bash script
|
||||||
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
|
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
|
||||||
bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "")
|
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||||
for k, v in env_vars_to_replace.items():
|
for k, v in env_vars_to_replace.items():
|
||||||
bash_script = bash_script.replace(k, str(v))
|
bash_script = bash_script.replace(k, str(v))
|
||||||
output_dir = tempfile.mkdtemp(prefix="output")
|
output_dir = tempfile.mkdtemp(prefix="output_mbart")
|
||||||
|
|
||||||
if CUDA_AVAILABLE:
|
bash_script = bash_script.replace("--fp16 ", "")
|
||||||
gpus = 1 # torch.cuda.device_count()
|
|
||||||
else:
|
|
||||||
gpus = 0
|
|
||||||
bash_script = bash_script.replace("--fp16", "")
|
|
||||||
testargs = (
|
testargs = (
|
||||||
["finetune.py"]
|
["finetune.py"]
|
||||||
+ bash_script.split()
|
+ bash_script.split()
|
||||||
+ [
|
+ [
|
||||||
f"--output_dir={output_dir}",
|
f"--output_dir={output_dir}",
|
||||||
f"--gpus={gpus}",
|
"--gpus=1",
|
||||||
"--learning_rate=3e-1",
|
"--learning_rate=3e-1",
|
||||||
"--warmup_steps=0",
|
"--warmup_steps=0",
|
||||||
"--val_check_interval=1.0",
|
"--val_check_interval=1.0",
|
||||||
@@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script():
|
|||||||
metrics = load_json(model.metrics_save_path)
|
metrics = load_json(model.metrics_save_path)
|
||||||
first_step_stats = metrics["val"][0]
|
first_step_stats = metrics["val"][0]
|
||||||
last_step_stats = metrics["val"][-1]
|
last_step_stats = metrics["val"][-1]
|
||||||
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
|
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
|
||||||
|
|
||||||
|
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||||
|
|
||||||
|
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||||
|
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||||
|
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||||
|
|
||||||
|
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||||
|
contents = os.listdir(output_dir)
|
||||||
|
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||||
|
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||||
|
ckpt = torch.load(full_path, map_location="cpu")
|
||||||
|
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||||
|
assert expected_key in ckpt["state_dict"]
|
||||||
|
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||||
|
|
||||||
|
# TODO(SS): turn on args.do_predict when PL bug fixed.
|
||||||
|
if args.do_predict:
|
||||||
|
contents = {os.path.basename(p) for p in contents}
|
||||||
|
assert "test_generations.txt" in contents
|
||||||
|
assert "test_results.txt" in contents
|
||||||
|
# assert len(metrics["val"]) == desired_n_evals
|
||||||
|
assert len(metrics["test"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@timeout_decorator.timeout(600)
|
||||||
|
@slow
|
||||||
|
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||||
|
def test_opus_mt_distill_script():
|
||||||
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
|
env_vars_to_replace = {
|
||||||
|
"--fp16_opt_level=O1": "",
|
||||||
|
"$MAX_LEN": 128,
|
||||||
|
"$BS": 16,
|
||||||
|
"$GAS": 1,
|
||||||
|
"$ENRO_DIR": data_dir,
|
||||||
|
"$m": "sshleifer/student_marian_en_ro_6_1",
|
||||||
|
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Clean up bash script
|
||||||
|
bash_script = (
|
||||||
|
Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
|
||||||
|
)
|
||||||
|
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||||
|
bash_script = bash_script.replace("--fp16 ", " ")
|
||||||
|
|
||||||
|
for k, v in env_vars_to_replace.items():
|
||||||
|
bash_script = bash_script.replace(k, str(v))
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||||
|
bash_script = bash_script.replace("--fp16", "")
|
||||||
|
epochs = 6
|
||||||
|
testargs = (
|
||||||
|
["distillation.py"]
|
||||||
|
+ bash_script.split()
|
||||||
|
+ [
|
||||||
|
f"--output_dir={output_dir}",
|
||||||
|
"--gpus=1",
|
||||||
|
"--learning_rate=1e-3",
|
||||||
|
f"--num_train_epochs={epochs}",
|
||||||
|
"--warmup_steps=10",
|
||||||
|
"--val_check_interval=1.0",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.do_predict = False
|
||||||
|
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||||
|
|
||||||
|
model = distill_main(args)
|
||||||
|
|
||||||
|
# Check metrics
|
||||||
|
metrics = load_json(model.metrics_save_path)
|
||||||
|
first_step_stats = metrics["val"][0]
|
||||||
|
last_step_stats = metrics["val"][-1]
|
||||||
|
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
|
||||||
|
|
||||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||||
|
|
||||||
|
|||||||
@@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
""".split()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
testargs += ["--fp16", "--gpus=1"]
|
testargs += ["--gpus=1"]
|
||||||
|
if is_cuda_and_apex_avaliable():
|
||||||
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_pl_glue.main()
|
result = run_pl_glue.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user