From 4208f496ee6554a998c3e7dd416cad3b1c1a0111 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 19 Nov 2020 10:43:15 -0500 Subject: [PATCH] Better filtering of the model outputs in Trainer (#8633) * Better filtering of the model outputs in Trainer * Fix examples tests * Add test for Lysandre --- examples/seq2seq/seq2seq_trainer.py | 8 ++- src/transformers/configuration_utils.py | 2 + .../models/bart/configuration_bart.py | 1 + .../models/ctrl/configuration_ctrl.py | 1 + .../models/gpt2/configuration_gpt2.py | 1 + .../models/marian/configuration_marian.py | 1 + .../models/mbart/configuration_mbart.py | 1 + .../models/mt5/configuration_mt5.py | 1 + .../models/pegasus/configuration_pegasus.py | 1 + .../prophetnet/configuration_prophetnet.py | 1 + .../models/reformer/configuration_reformer.py | 1 + .../models/t5/configuration_t5.py | 1 + .../transfo_xl/configuration_transfo_xl.py | 1 + .../models/xlnet/configuration_xlnet.py | 1 + src/transformers/trainer.py | 57 ++++++++++++++----- tests/test_trainer.py | 55 ++++++++++++++++++ 16 files changed, 119 insertions(+), 15 deletions(-) diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 520df0e87b..99826e2228 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -153,7 +153,11 @@ class Seq2SeqTrainer(Trainer): return loss def prediction_step( - self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on :obj:`model` using obj:`inputs`. diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 4e55c4db65..f7587faac0 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -43,6 +43,8 @@ class PretrainedConfig(object): - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig` like: :class:`~transformers.EncoderDecoderConfig` or :class:`~RagConfig`. + - **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at + dictionary outputs of the model during inference. Args: name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`): diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 8533a013be..67e92e8966 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -110,6 +110,7 @@ class BartConfig(PretrainedConfig): :obj:`True` for `bart-large-cnn`. """ model_type = "bart" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/ctrl/configuration_ctrl.py b/src/transformers/models/ctrl/configuration_ctrl.py index faffaa0df9..c2633c49b8 100644 --- a/src/transformers/models/ctrl/configuration_ctrl.py +++ b/src/transformers/models/ctrl/configuration_ctrl.py @@ -77,6 +77,7 @@ class CTRLConfig(PretrainedConfig): """ model_type = "ctrl" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index 25cdcb49f2..a30da248a5 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig): """ model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index d5769bcb9c..a17531bb2f 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -97,3 +97,4 @@ class MarianConfig(BartConfig): """ model_type = "marian" + keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 7436660278..c8b4540e1e 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -102,3 +102,4 @@ class MBartConfig(BartConfig): """ model_type = "mbart" + keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py index 23bde10047..09e9ac2262 100644 --- a/src/transformers/models/mt5/configuration_mt5.py +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig): Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. """ model_type = "mt5" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index f134ea5832..585f06ddb4 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -141,4 +141,5 @@ class PegasusConfig(BartConfig): """ model_type = "pegasus" + keys_to_ignore_at_inference = ["past_key_values"] # The implementation of the config object is in BartConfig diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index f652043e66..fdb6f5f300 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. """ model_type = "prophetnet" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/reformer/configuration_reformer.py b/src/transformers/models/reformer/configuration_reformer.py index 69d178875e..9e860a48c9 100755 --- a/src/transformers/models/reformer/configuration_reformer.py +++ b/src/transformers/models/reformer/configuration_reformer.py @@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig): >>> configuration = model.config """ model_type = "reformer" + keys_to_ignore_at_inference = ["past_buckets_states"] def __init__( self, diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 48bdb6c329..75b396742c 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -71,6 +71,7 @@ class T5Config(PretrainedConfig): the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. """ model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/transfo_xl/configuration_transfo_xl.py b/src/transformers/models/transfo_xl/configuration_transfo_xl.py index 9885cbfa2e..1008f3488a 100644 --- a/src/transformers/models/transfo_xl/configuration_transfo_xl.py +++ b/src/transformers/models/transfo_xl/configuration_transfo_xl.py @@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig): """ model_type = "transfo-xl" + keys_to_ignore_at_inference = ["mems"] def __init__( self, diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py index db10231790..f0592a8d0b 100644 --- a/src/transformers/models/xlnet/configuration_xlnet.py +++ b/src/transformers/models/xlnet/configuration_xlnet.py @@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig): """ model_type = "xlnet" + keys_to_ignore_at_inference = ["mems"] def __init__( self, diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 950e242913..64c363afb5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1098,10 +1098,11 @@ class Trainer: """ outputs = model(**inputs) # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index] # We don't use .loss here since the model may return tuples instead of ModelOutput. - return outputs[0] + return outputs["loss"] if isinstance(outputs, dict) else outputs[0] def is_local_process_zero(self) -> bool: """ @@ -1220,7 +1221,9 @@ class Trainer: logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) shutil.rmtree(checkpoint) - def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: + def evaluate( + self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None + ) -> Dict[str, float]: """ Run evaluation and returns metrics. @@ -1234,6 +1237,9 @@ class Trainer: Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the :obj:`__len__` method. + ignore_keys (:obj:`Lst[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. Returns: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The @@ -1250,6 +1256,7 @@ class Trainer: # No point gathering the predictions if there are no metrics, otherwise we defer to # self.args.prediction_loss_only prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, ) self.log(output.metrics) @@ -1261,7 +1268,7 @@ class Trainer: self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) return output.metrics - def predict(self, test_dataset: Dataset) -> PredictionOutput: + def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. @@ -1272,6 +1279,9 @@ class Trainer: test_dataset (:obj:`Dataset`): Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` + ignore_keys (:obj:`Lst[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. .. note:: @@ -1291,10 +1301,14 @@ class Trainer: test_dataloader = self.get_test_dataloader(test_dataset) - return self.prediction_loop(test_dataloader, description="Prediction") + return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys) def prediction_loop( - self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, ) -> PredictionOutput: """ Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. @@ -1346,7 +1360,7 @@ class Trainer: self.callback_handler.eval_dataloader = dataloader for step, inputs in enumerate(dataloader): - loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) if loss is not None: losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) @@ -1410,7 +1424,11 @@ class Trainer: return nested_numpify(tensors) def prediction_step( - self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on :obj:`model` using obj:`inputs`. @@ -1427,6 +1445,9 @@ class Trainer: argument :obj:`labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (:obj:`bool`): Whether or not to return the loss only. + ignore_keys (:obj:`Lst[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. Return: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and @@ -1434,6 +1455,11 @@ class Trainer: """ has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] with torch.no_grad(): if self.args.fp16 and _use_native_amp: @@ -1442,16 +1468,21 @@ class Trainer: else: outputs = model(**inputs) if has_labels: - loss = outputs[0].mean().detach() - logits = outputs[1:] + if isinstance(outputs, dict): + loss = outputs["loss"].mean().detach() + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss = outputs[0].mean().detach() + logits = outputs[1:] else: loss = None - # Slicing so we get a tuple even if `outputs` is a `ModelOutput`. - logits = outputs[:] + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] - # Remove the past from the logits. - logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :] if prediction_loss_only: return (loss, None, None) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b5db8c0712..5d80654d48 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -44,6 +44,8 @@ if is_torch_available(): DataCollatorForLanguageModeling, GlueDataset, GlueDataTrainingArguments, + GPT2Config, + GPT2LMHeadModel, LineByLineTextDataset, PreTrainedModel, TextDataset, @@ -73,6 +75,18 @@ class RegressionDataset: return result +class RepeatDataset: + def __init__(self, x, length=64): + self.x = x + self.length = length + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"input_ids": self.x, "labels": self.x} + + class DynamicShapesDataset: def __init__(self, length=64, seed=42, batch_size=8): self.length = length @@ -136,6 +150,20 @@ if is_torch_available(): loss = torch.nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + class RegressionDictModel(torch.nn.Module): + def __init__(self, a=0, b=0): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + self.config = None + + def forward(self, input_x=None, labels=None, **kwargs): + y = input_x * self.a + self.b + result = {"output": y} + if labels is not None: + result["loss"] = torch.nn.functional.mse_loss(y, labels) + return result + class RegressionPreTrainedModel(PreTrainedModel): config_class = RegressionModelConfig base_model_prefix = "regression" @@ -236,6 +264,33 @@ class TrainerIntegrationTest(unittest.TestCase): metrics = trainer.evaluate() self.assertEqual(metrics[metric], best_value) + def test_trainer_works_with_dict(self): + # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break + # anything. + train_dataset = RegressionDataset() + eval_dataset = RegressionDataset() + model = RegressionDictModel() + args = TrainingArguments("./regression") + trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + trainer.train() + _ = trainer.evaluate() + _ = trainer.predict(eval_dataset) + + def test_evaluation_with_keys_to_drop(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + eval_dataset = RepeatDataset(x) + args = TrainingArguments("./test") + trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset) + # By default the past_key_values are removed + result = trainer.predict(eval_dataset) + self.assertTrue(isinstance(result.predictions, np.ndarray)) + # We can still get them by setting ignore_keys to [] + result = trainer.predict(eval_dataset, ignore_keys=[]) + self.assertTrue(isinstance(result.predictions, tuple)) + self.assertEqual(len(result.predictions), 2) + def test_training_arguments_are_left_untouched(self): trainer = get_regression_trainer() trainer.train()