From 48f23f92a80500b1475a84566841efe6581a94c0 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 1 Oct 2020 00:33:01 -0400 Subject: [PATCH] [s2sTrainer] test + code cleanup (#7467) --- examples/seq2seq/finetune.py | 20 +------ examples/seq2seq/finetune_trainer.py | 72 ++++++++--------------- examples/seq2seq/seq2seq_trainer.py | 35 +++++------ examples/seq2seq/test_finetune_trainer.py | 72 +++++++++++++---------- examples/seq2seq/utils.py | 19 ++++++ 5 files changed, 102 insertions(+), 116 deletions(-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index b11fee1eda..65343b2f0c 100755 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -26,6 +26,7 @@ from utils import ( calculate_bleu, calculate_rouge, flatten_list, + freeze_embeds, freeze_params, get_git_info, label_smoothed_nll_loss, @@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer): assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" if self.hparams.freeze_embeds: - self.freeze_embeds() + freeze_embeds(self.model) if self.hparams.freeze_encoder: freeze_params(self.model.get_encoder()) assert_all_frozen(self.model.get_encoder()) @@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer): Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset ) self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams - assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1" if self.hparams.eval_max_gen_length is not None: self.eval_max_length = self.hparams.eval_max_gen_length else: self.eval_max_length = self.model.config.max_length self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric - def freeze_embeds(self): - """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" - if self.model_type == "t5": - freeze_params(self.model.shared) - for d in [self.model.encoder, self.model.decoder]: - freeze_params(d.embed_tokens) - elif self.model_type == "fsmt": - for d in [self.model.model.encoder, self.model.model.decoder]: - freeze_params(d.embed_positions) - freeze_params(d.embed_tokens) - else: - freeze_params(self.model.model.shared) - for d in [self.model.model.encoder, self.model.model.decoder]: - freeze_params(d.embed_positions) - freeze_params(d.embed_tokens) - def forward(self, input_ids, **kwargs): return self.model(input_ids, **kwargs) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 8dd1f6234a..baa0d5b70e 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -1,4 +1,3 @@ -import json import logging import os import sys @@ -29,10 +28,13 @@ from utils import ( assert_all_frozen, calculate_bleu, calculate_rouge, + freeze_embeds, freeze_params, lmap, + save_json, trim_batch, use_task_specific_params, + write_txt_file, ) @@ -43,6 +45,7 @@ class Seq2SeqDataCollator: def __init__(self, tokenizer, data_args, tpu_num_cores=None): self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id + assert self.pad_token_id is not None, "self.pad_token_id must be defined" self.data_args = data_args self.tpu_num_cores = tpu_num_cores self.add_prefix_space = isinstance(tokenizer, BartTokenizer) @@ -65,10 +68,8 @@ class Seq2SeqDataCollator: if isinstance(self.tokenizer, T5Tokenizer): decoder_input_ids = self._shift_right_t5(labels) - labels = labels else: decoder_input_ids = shift_tokens_right(labels, self.pad_token_id) - labels = labels batch = { "input_ids": input_ids, @@ -79,17 +80,10 @@ class Seq2SeqDataCollator: return batch def _shift_right_t5(self, input_ids): - decoder_start_token_id = self.pad_token_id - - assert ( - decoder_start_token_id is not None - ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" - # shift inputs to the right shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id - + shifted_input_ids[..., 0] = self.pad_token_id return shifted_input_ids def _encode(self, batch) -> Dict[str, torch.Tensor]: @@ -267,17 +261,15 @@ def main(): use_task_specific_params(model, data_args.task) # set num_beams for evaluation - if data_args.eval_beams is not None: - model.config.num_beams = data_args.eval_beams - assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1" - - # set max length for generation - model.config.max_generate_length = data_args.val_max_target_length + if data_args.eval_beams is None: + data_args.eval_beams = model.config.num_beams # set decoder_start_token_id for MBart if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): - decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] - model.config.decoder_start_token_id = decoder_start_token_id + assert ( + data_args.tgt_lang is not None and data_args.src_lang is not None + ), "mBart requires --tgt_lang and --src_lang" + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: def non_pad_len(tokens: np.ndarray) -> int: @@ -293,32 +285,20 @@ def main(): def summarization_metrics(pred: EvalPrediction) -> Dict: pred_str, label_str = decode_pred(pred) rouge: Dict = calculate_rouge(pred_str, label_str) - summ_len = np.mean(lmap(non_pad_len, pred.predictions)) + summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) rouge.update({"gen_len": summ_len}) return rouge def translation_metrics(pred: EvalPrediction) -> Dict: pred_str, label_str = decode_pred(pred) bleu: Dict = calculate_bleu(pred_str, label_str) - gen_len = np.mean(lmap(non_pad_len, pred.predictions)) + gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) bleu.update({"gen_len": gen_len}) return bleu compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics return compute_metrics_fn - def freeze_embeds(model: torch.nn.Module): - """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" - try: - freeze_params(model.model.shared) - for d in [model.model.encoder, model.model.decoder]: - freeze_params(d.embed_positions) - freeze_params(d.embed_tokens) - except AttributeError: - freeze_params(model.shared) - for d in [model.encoder, model.decoder]: - freeze_params(d.embed_tokens) - if model_args.freeze_embeds: freeze_embeds(model) if model_args.freeze_encoder: @@ -376,6 +356,7 @@ def main(): eval_dataset=eval_dataset, data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None, + data_args=data_args, ) # Training @@ -396,41 +377,36 @@ def main(): result = trainer.evaluate() - output_eval_file = os.path.join(training_args.output_dir, "eval_results.json") if trainer.is_world_process_zero(): logger.info("***** Eval results *****") for key, value in result.items(): logger.info(" %s = %s", key, value) - - with open(output_eval_file, "w") as f: - json.dump(result, f) - + save_json(result, os.path.join(training_args.output_dir, "eval_results.json")) eval_results.update(result) if training_args.do_predict: logging.info("*** Test ***") test_output = trainer.predict(test_dataset=test_dataset) - test_metrics = test_output.metrics - test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()} - - output_test_file = os.path.join(training_args.output_dir, "test_results.json") + test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()} if trainer.is_world_process_zero(): logger.info("***** Test results *****") for key, value in test_metrics.items(): logger.info(" %s = %s", key, value) - with open(output_test_file, "w") as f: - json.dump(test_metrics, f) + save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json")) + eval_results.update(test_metrics) if training_args.predict_with_generate: - test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True) + test_preds = tokenizer.batch_decode( + test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) test_preds = lmap(str.strip, test_preds) - output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt") - with open(output_test_pred_file, "w") as f: - f.write("\n".join(test_preds)) + write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt")) + if trainer.is_world_process_zero(): + save_json(eval_results, "all_results.json") return eval_results diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 195eb4768e..885e7263f8 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -20,6 +20,12 @@ logger = logging.getLogger(__name__) class Seq2SeqTrainer(Trainer): + def __init__(self, data_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_args = data_args + self.max_gen_length = data_args.val_max_target_length + self.pad_token_id = self.model.config.pad_token_id + def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset): return None @@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer): labels = inputs.pop("labels") outputs = model(**inputs, use_cache=False) logits = outputs[0] - return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id) + return self._compute_loss(logits, labels, ignore_index=self.pad_token_id) def _compute_loss(self, logits, labels, ignore_index): if self.args.label_smoothing == 0: @@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer): """ inputs = self._prepare_inputs(inputs) - max_length = ( - model.config.max_generate_length - if hasattr(model.config, "max_generate_length") - else model.config.max_position_embeddings - ) - with torch.no_grad(): if self.args.predict_with_generate and not self.args.prediction_loss_only: generated_tokens = model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], use_cache=True, - num_beams=model.config.num_beams, - max_length=max_length, + num_beams=self.data_args.eval_beams, + max_length=self.max_gen_length, ) # in case the batch is shorter than max length, the output should be padded generated_tokens = self._pad_tensors_to_max_len( - generated_tokens, max_length, model.config.pad_token_id + generated_tokens, self.max_gen_length, self.pad_token_id ) labels_out = inputs.get("labels") - outputs = model(**inputs) - logits = outputs[1] - loss = self._compute_loss(logits, labels_out, model.config.pad_token_id) + # Call forward again to get loss # TODO: avoidable? + outputs = model(**inputs, use_cache=False) + loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id) loss = loss.mean().item() if self.args.prediction_loss_only: - logits = None - else: - logits = generated_tokens if self.args.predict_with_generate else logits + return (loss, None, None) - if self.args.prediction_loss_only: - return (loss, None, None) + logits = generated_tokens if self.args.predict_with_generate else outputs[1] labels_out = labels_out.detach() - labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id) + labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id) return (loss, logits.detach(), labels) def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id): diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 2863a5aa5f..80f8d699b0 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -3,36 +3,54 @@ import sys import tempfile from unittest.mock import patch -from transformers import BartForConditionalGeneration, MarianMTModel from transformers.testing_utils import slow +from transformers.trainer_utils import set_seed from .finetune_trainer import main from .test_seq2seq_examples import MBART_TINY from .utils import load_json -MODEL_NAME = MBART_TINY -# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1" +set_seed(42) MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" -@slow -def test_model_download(): - """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" - BartForConditionalGeneration.from_pretrained(MODEL_NAME) - MarianMTModel.from_pretrained(MARIAN_MODEL) - - -@slow def test_finetune_trainer(): + output_dir = run_trainer(1, "12", MBART_TINY, 1) + logs = load_json(os.path.join(output_dir, "log_history.json")) + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + assert "eval_bleu" in first_step_stats + + +@slow +def test_finetune_trainer_slow(): + # TODO(SS): This will fail on devices with more than 1 GPU. + # There is a missing call to __init__process_group somewhere + output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3) + + # Check metrics + logs = load_json(os.path.join(output_dir, "log_history.json")) + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + last_step_stats = eval_metrics[-1] + + assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing + assert isinstance(last_step_stats["eval_bleu"], float) + + # test if do_predict saves generations and metrics + contents = os.listdir(output_dir) + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.json" in contents + + +def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): data_dir = "examples/seq2seq/test_data/wmt_en_ro" - output_dir = tempfile.mkdtemp(prefix="marian_output") - max_len = "128" - num_train_epochs = 4 - eval_steps = 2 + output_dir = tempfile.mkdtemp(prefix="test_output") argv = [ "--model_name_or_path", - MARIAN_MODEL, + model_name, "--data_dir", data_dir, "--output_dir", @@ -72,25 +90,17 @@ def test_finetune_trainer(): "--sortish_sampler", "--label_smoothing", "0.1", + # "--eval_beams", + # "2", "--task", "translation", + "--tgt_lang", + "ro_RO", + "--src_lang", + "en_XX", ] - testargs = ["finetune_trainer.py"] + argv with patch.object(sys, "argv", testargs): main() - # Check metrics - logs = load_json(os.path.join(output_dir, "log_history.json")) - eval_metrics = [log for log in logs if "eval_loss" in log.keys()] - first_step_stats = eval_metrics[0] - last_step_stats = eval_metrics[-1] - - assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing - assert isinstance(last_step_stats["eval_bleu"], float) - - # test if do_predict saves generations and metrics - contents = os.listdir(output_dir) - contents = {os.path.basename(p) for p in contents} - assert "test_generations.txt" in contents - assert "test_results.json" in contents + return output_dir diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 43f5caf05f..f64f0104cb 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -441,6 +441,25 @@ def freeze_params(model: nn.Module): par.requires_grad = False +def freeze_embeds(model): + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" + model_type = model.config.model_type + + if model_type == "t5": + freeze_params(model.shared) + for d in [model.encoder, model.decoder]: + freeze_params(d.embed_tokens) + elif model_type == "fsmt": + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + else: + freeze_params(model.model.shared) + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + + def grad_status(model: nn.Module) -> Iterable: return (par.requires_grad for par in model.parameters())