Add support for past states (#5399)

* Add support for past states

* Style and forgotten self

* You mean, documenting is not enough? I have to actually add it too?

* Add memory support during evaluation

* Fix tests in eval and add TF support

* No need to change this line anymore
This commit is contained in:
Sylvain Gugger
2020-07-01 08:11:55 -04:00
committed by GitHub
parent 4ade7491f4
commit 64e3d966b1
4 changed files with 57 additions and 2 deletions

View File

@@ -493,6 +493,10 @@ class Trainer:
else: else:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, inputs in enumerate(epoch_iterator): for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training # Skip past any already trained steps if resuming training
@@ -575,6 +579,9 @@ class Trainer:
if self.tb_writer: if self.tb_writer:
self.tb_writer.close() self.tb_writer.close()
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
return TrainOutput(self.global_step, tr_loss / self.global_step) return TrainOutput(self.global_step, tr_loss / self.global_step)
@@ -617,9 +624,15 @@ class Trainer:
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device) inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past
outputs = model(**inputs) outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc) loss = outputs[0] # model outputs are always tuple in transformers (see doc)
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if self.args.n_gpu > 1: if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.args.gradient_accumulation_steps > 1: if self.args.gradient_accumulation_steps > 1:
@@ -802,12 +815,17 @@ class Trainer:
if is_torch_tpu_available(): if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
if self.args.past_index >= 0:
past = None
for inputs in tqdm(dataloader, desc=description): for inputs in tqdm(dataloader, desc=description):
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device) inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0:
inputs["mems"] = past
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
@@ -816,6 +834,8 @@ class Trainer:
eval_losses += [step_eval_loss.mean().item()] eval_losses += [step_eval_loss.mean().item()]
else: else:
logits = outputs[0] logits = outputs[0]
if self.args.past_index >= 0:
past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
if not prediction_loss_only: if not prediction_loss_only:
if preds is None: if preds is None:

View File

@@ -240,6 +240,10 @@ class TFTrainer:
step: int = 1 step: int = 1
# Reset the past mems state at the beginning of the evaluation if necessary.
if self.args.past_index >= 0:
self._past = None
for features, labels in dataset: for features, labels in dataset:
step = tf.convert_to_tensor(step, dtype=tf.int64) step = tf.convert_to_tensor(step, dtype=tf.int64)
loss, logits = self._evaluate_steps(features, labels) loss, logits = self._evaluate_steps(features, labels)
@@ -288,6 +292,10 @@ class TFTrainer:
if not key.startswith("eval_"): if not key.startswith("eval_"):
metrics[f"eval_{key}"] = metrics.pop(key) metrics[f"eval_{key}"] = metrics.pop(key)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _log(self, logs: Dict[str, float]) -> None: def _log(self, logs: Dict[str, float]) -> None:
@@ -405,6 +413,9 @@ class TFTrainer:
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", t_total)
for epoch_iter in range(epochs_trained, int(epochs + 1)): for epoch_iter in range(epochs_trained, int(epochs + 1)):
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)): for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
self.global_step = iterations.numpy() self.global_step = iterations.numpy()
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
@@ -444,6 +455,10 @@ class TFTrainer:
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0: if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
break break
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
def _training_steps(self, ds, optimizer): def _training_steps(self, ds, optimizer):
""" """
Returns a generator over training steps (i.e. parameters update). Returns a generator over training steps (i.e. parameters update).
@@ -518,10 +533,15 @@ class TFTrainer:
labels: the batched labels. labels: the batched labels.
training: run the model in training mode or not training: run the model in training mode or not
""" """
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
features["mems"] = self._past
if isinstance(labels, (dict)): if isinstance(labels, (dict)):
loss, logits = self.model(features, training=training, **labels)[:2] outputs = self.model(features, training=training, **labels)[:2]
else: else:
loss, logits = self.model(features, labels=labels, training=training)[:2] outputs = self.model(features, labels=labels, training=training)[:2]
loss, logits = outputs[:2]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu) loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
return loss, logits return loss, logits

View File

@@ -102,6 +102,11 @@ class TrainingArguments:
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not. or not.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``.
""" """
output_dir: str = field( output_dir: str = field(
@@ -203,6 +208,11 @@ class TrainingArguments:
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
) )
past_index: int = field(
default=-1,
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
)
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
""" """

View File

@@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments):
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not. or not.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`): tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on. The name of the TPU the process is running on.
eval_steps (:obj:`int`, `optional`, defaults to 1000): eval_steps (:obj:`int`, `optional`, defaults to 1000):