From 64e3d966b1131c15b5905b1e1e582d4bebac1ef0 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 1 Jul 2020 08:11:55 -0400 Subject: [PATCH] Add support for past states (#5399) * Add support for past states * Style and forgotten self * You mean, documenting is not enough? I have to actually add it too? * Add memory support during evaluation * Fix tests in eval and add TF support * No need to change this line anymore --- src/transformers/trainer.py | 20 ++++++++++++++++++++ src/transformers/trainer_tf.py | 24 ++++++++++++++++++++++-- src/transformers/training_args.py | 10 ++++++++++ src/transformers/training_args_tf.py | 5 +++++ 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fa40947b1a..7b974814ad 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -493,6 +493,10 @@ class Trainer: else: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) + # Reset the past mems state at the beginning of each epoch if necessary. + if self.args.past_index >= 0: + self._past = None + for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training @@ -575,6 +579,9 @@ class Trainer: if self.tb_writer: self.tb_writer.close() + if self.args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") return TrainOutput(self.global_step, tr_loss / self.global_step) @@ -617,9 +624,15 @@ class Trainer: if isinstance(v, torch.Tensor): inputs[k] = v.to(self.args.device) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in transformers (see doc) + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training if self.args.gradient_accumulation_steps > 1: @@ -802,12 +815,17 @@ class Trainer: if is_torch_tpu_available(): dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) + if self.args.past_index >= 0: + past = None + for inputs in tqdm(dataloader, desc=description): has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) for k, v in inputs.items(): if isinstance(v, torch.Tensor): inputs[k] = v.to(self.args.device) + if self.args.past_index >= 0: + inputs["mems"] = past with torch.no_grad(): outputs = model(**inputs) @@ -816,6 +834,8 @@ class Trainer: eval_losses += [step_eval_loss.mean().item()] else: logits = outputs[0] + if self.args.past_index >= 0: + past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] if not prediction_loss_only: if preds is None: diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index e1afbc1743..c61d3661f3 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -240,6 +240,10 @@ class TFTrainer: step: int = 1 + # Reset the past mems state at the beginning of the evaluation if necessary. + if self.args.past_index >= 0: + self._past = None + for features, labels in dataset: step = tf.convert_to_tensor(step, dtype=tf.int64) loss, logits = self._evaluate_steps(features, labels) @@ -288,6 +292,10 @@ class TFTrainer: if not key.startswith("eval_"): metrics[f"eval_{key}"] = metrics.pop(key) + if self.args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) def _log(self, logs: Dict[str, float]) -> None: @@ -405,6 +413,9 @@ class TFTrainer: logger.info(" Total optimization steps = %d", t_total) for epoch_iter in range(epochs_trained, int(epochs + 1)): + # Reset the past mems state at the beginning of each epoch if necessary. + if self.args.past_index >= 0: + self._past = None for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)): self.global_step = iterations.numpy() self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch @@ -444,6 +455,10 @@ class TFTrainer: if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0: break + if self.args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + def _training_steps(self, ds, optimizer): """ Returns a generator over training steps (i.e. parameters update). @@ -518,10 +533,15 @@ class TFTrainer: labels: the batched labels. training: run the model in training mode or not """ + if self.args.past_index >= 0 and getattr(self, "_past", None) is not None: + features["mems"] = self._past if isinstance(labels, (dict)): - loss, logits = self.model(features, training=training, **labels)[:2] + outputs = self.model(features, training=training, **labels)[:2] else: - loss, logits = self.model(features, labels=labels, training=training)[:2] + outputs = self.model(features, labels=labels, training=training)[:2] + loss, logits = outputs[:2] + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] loss += sum(self.model.losses) * (1.0 / self.args.n_gpu) return loss, logits diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c0d7a4913f..d78ec19dbe 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -102,6 +102,11 @@ class TrainingArguments: dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) or not. + past_index (:obj:`int`, `optional`, defaults to -1): + Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can + make use of the past hidden states for their predictions. If this argument is set to a positive int, the + ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model + at the next training step under the keyword argument ``mems``. """ output_dir: str = field( @@ -203,6 +208,11 @@ class TrainingArguments: default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} ) + past_index: int = field( + default=-1, + metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, + ) + @property def train_batch_size(self) -> int: """ diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index f87c7bc994..942dc13892 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments): dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) or not. + past_index (:obj:`int`, `optional`, defaults to -1): + Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can + make use of the past hidden states for their predictions. If this argument is set to a positive int, the + ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model + at the next training step under the keyword argument ``mems``. tpu_name (:obj:`str`, `optional`): The name of the TPU the process is running on. eval_steps (:obj:`int`, `optional`, defaults to 1000):