Reorganize examples (#9010)
* Reorganize example folder * Continue reorganization * Change requirements for tests * Final cleanup * Finish regroup with tests all passing * Copyright * Requirements and readme * Make a full link for the documentation * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add symlink * Reorg again * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Adapt title * Update to new strucutre * Remove test * Update READMEs Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -1,96 +1,32 @@
|
||||
import argparse
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from parameterized import parameterized
|
||||
from run_eval import generate_summaries_or_translations, run_generate
|
||||
from run_eval import run_generate
|
||||
from run_eval_search import run_search
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
|
||||
from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
|
||||
from utils import ROUGE_KEYS
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"max_tokens_per_batch": None,
|
||||
"supervise_forward": True,
|
||||
"normalize_hidden": True,
|
||||
"label_smoothing": 0.2,
|
||||
"eval_max_gen_length": None,
|
||||
"eval_beams": 1,
|
||||
"val_metric": "loss",
|
||||
"save_top_k": 1,
|
||||
"adafactor": True,
|
||||
"early_stopping_patience": 2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"accumulate_grad_batches": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"lr_scheduler": "linear",
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"max_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
"overwrite_output_dir": False,
|
||||
"student": None,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
@@ -98,187 +34,15 @@ def _dump_articles(path: Path, articles: list):
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
T5_TINIER = "sshleifer/t5-tinier-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
FSMT_TINY = "stas/tiny-wmt19-en-de"
|
||||
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
||||
|
||||
def make_test_data_dir(tmp_dir):
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestSummarizationDistiller(TestCasePlus):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_hub_configs(self):
|
||||
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
||||
|
||||
model_list = HfApi().model_list()
|
||||
org = "sshleifer"
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
||||
failures = []
|
||||
for m in model_ids:
|
||||
if m in allowed_to_be_broken:
|
||||
continue
|
||||
try:
|
||||
AutoConfig.from_pretrained(m)
|
||||
except Exception:
|
||||
failures.append(m)
|
||||
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_checkpointing_with_teacher(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
max_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp() # XXX: not being cleaned up
|
||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
out_path_new = self.get_auto_remove_tmp_dir()
|
||||
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||
|
||||
def test_loss_fn(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
|
||||
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
|
||||
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
||||
model_computed_loss = model(
|
||||
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
|
||||
).loss
|
||||
|
||||
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
|
||||
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
smoothed_loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
|
||||
)
|
||||
with self.assertRaises(AssertionError):
|
||||
# TODO: understand why this breaks
|
||||
self.assertEqual(nll_loss, model_computed_loss)
|
||||
|
||||
def test_distill_mbart(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
task="translation",
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
tokenizer_name=MBART_TINY,
|
||||
teacher=MBART_TINY,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
assert model.model.config.model_type == "mbart"
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
|
||||
assert len(all_files) > 2
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher=T5_TINY,
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=T5_TINY,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_different_base_models(self):
|
||||
updates = dict(
|
||||
teacher=T5_TINY,
|
||||
student=T5_TINIER,
|
||||
model_name_or_path=T5_TINIER,
|
||||
tokenizer_name=T5_TINIER,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def _test_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
model = distill_main(argparse.Namespace(**args_d))
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
ckpt_files = [p for p in contents if p.endswith("ckpt")]
|
||||
assert len(ckpt_files) > 0
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
return model
|
||||
|
||||
|
||||
class TestTheRest(TestCasePlus):
|
||||
def run_eval_tester(self, model):
|
||||
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||
@@ -365,167 +129,3 @@ class TestTheRest(TestCasePlus):
|
||||
assert w not in cs.out
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
@parameterized.expand(
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||
)
|
||||
def test_finetune(self, model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
module = main(args)
|
||||
|
||||
input_embeds = module.model.get_input_embeddings()
|
||||
assert not input_embeds.weight.requires_grad
|
||||
if model == T5_TINY:
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
elif model == FSMT_TINY:
|
||||
fsmt = module.model.model
|
||||
embed_pos = fsmt.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not fsmt.decoder.embed_tokens.weight.requires_grad
|
||||
# check that embeds are not the same
|
||||
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
|
||||
else:
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not bart.shared.weight.requires_grad
|
||||
# check that embeds are the same
|
||||
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
example_batch = load_json(module.output_dir / "text_batch.json")
|
||||
assert isinstance(example_batch, dict)
|
||||
assert len(example_batch) >= 4
|
||||
|
||||
def test_finetune_extra_model_args(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# test models whose config includes the extra_model_args
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
args_d1[p] = 0.5
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
for p in extra_model_params:
|
||||
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
|
||||
|
||||
# test models whose config doesn't include the extra_model_args
|
||||
model = T5_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
args = argparse.Namespace(**args_d2)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
model = main(args)
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
|
||||
def test_finetune_lr_schedulers(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# emulate finetune.py
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = {"--help": True}
|
||||
|
||||
# --help test
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStdout() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "--help is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = lightning_base.arg_to_scheduler_metavar
|
||||
assert expected in cs.out, "--help is expected to list the supported schedulers"
|
||||
|
||||
# --lr_scheduler=non_existing_scheduler test
|
||||
unsupported_param = "non_existing_scheduler"
|
||||
args = {f"--lr_scheduler={unsupported_param}"}
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStderr() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "invalid argument is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = f"invalid choice: '{unsupported_param}'"
|
||||
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||
|
||||
# --lr_scheduler=existing_scheduler test
|
||||
supported_param = "cosine"
|
||||
args_d1 = args_d.copy()
|
||||
args_d1["lr_scheduler"] = supported_param
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert (
|
||||
getattr(model.hparams, "lr_scheduler") == supported_param
|
||||
), f"lr_scheduler={supported_param} shouldn't fail"
|
||||
|
||||
Reference in New Issue
Block a user