Add support for past states (#5399)
* Add support for past states * Style and forgotten self * You mean, documenting is not enough? I have to actually add it too? * Add memory support during evaluation * Fix tests in eval and add TF support * No need to change this line anymore
This commit is contained in:
@@ -493,6 +493,10 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
||||||
|
|
||||||
|
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = None
|
||||||
|
|
||||||
for step, inputs in enumerate(epoch_iterator):
|
for step, inputs in enumerate(epoch_iterator):
|
||||||
|
|
||||||
# Skip past any already trained steps if resuming training
|
# Skip past any already trained steps if resuming training
|
||||||
@@ -575,6 +579,9 @@ class Trainer:
|
|||||||
|
|
||||||
if self.tb_writer:
|
if self.tb_writer:
|
||||||
self.tb_writer.close()
|
self.tb_writer.close()
|
||||||
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
|
# Clean the state at the end of training
|
||||||
|
delattr(self, "_past")
|
||||||
|
|
||||||
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
||||||
return TrainOutput(self.global_step, tr_loss / self.global_step)
|
return TrainOutput(self.global_step, tr_loss / self.global_step)
|
||||||
@@ -617,9 +624,15 @@ class Trainer:
|
|||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
inputs[k] = v.to(self.args.device)
|
inputs[k] = v.to(self.args.device)
|
||||||
|
|
||||||
|
if self.args.past_index >= 0 and self._past is not None:
|
||||||
|
inputs["mems"] = self._past
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = outputs[self.args.past_index]
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
if self.args.gradient_accumulation_steps > 1:
|
if self.args.gradient_accumulation_steps > 1:
|
||||||
@@ -802,12 +815,17 @@ class Trainer:
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
||||||
|
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
past = None
|
||||||
|
|
||||||
for inputs in tqdm(dataloader, desc=description):
|
for inputs in tqdm(dataloader, desc=description):
|
||||||
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
||||||
|
|
||||||
for k, v in inputs.items():
|
for k, v in inputs.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
inputs[k] = v.to(self.args.device)
|
inputs[k] = v.to(self.args.device)
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
inputs["mems"] = past
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
@@ -816,6 +834,8 @@ class Trainer:
|
|||||||
eval_losses += [step_eval_loss.mean().item()]
|
eval_losses += [step_eval_loss.mean().item()]
|
||||||
else:
|
else:
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
||||||
|
|
||||||
if not prediction_loss_only:
|
if not prediction_loss_only:
|
||||||
if preds is None:
|
if preds is None:
|
||||||
|
|||||||
@@ -240,6 +240,10 @@ class TFTrainer:
|
|||||||
|
|
||||||
step: int = 1
|
step: int = 1
|
||||||
|
|
||||||
|
# Reset the past mems state at the beginning of the evaluation if necessary.
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = None
|
||||||
|
|
||||||
for features, labels in dataset:
|
for features, labels in dataset:
|
||||||
step = tf.convert_to_tensor(step, dtype=tf.int64)
|
step = tf.convert_to_tensor(step, dtype=tf.int64)
|
||||||
loss, logits = self._evaluate_steps(features, labels)
|
loss, logits = self._evaluate_steps(features, labels)
|
||||||
@@ -288,6 +292,10 @@ class TFTrainer:
|
|||||||
if not key.startswith("eval_"):
|
if not key.startswith("eval_"):
|
||||||
metrics[f"eval_{key}"] = metrics.pop(key)
|
metrics[f"eval_{key}"] = metrics.pop(key)
|
||||||
|
|
||||||
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
|
# Clean the state at the end of the evaluation
|
||||||
|
delattr(self, "_past")
|
||||||
|
|
||||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||||
|
|
||||||
def _log(self, logs: Dict[str, float]) -> None:
|
def _log(self, logs: Dict[str, float]) -> None:
|
||||||
@@ -405,6 +413,9 @@ class TFTrainer:
|
|||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
for epoch_iter in range(epochs_trained, int(epochs + 1)):
|
for epoch_iter in range(epochs_trained, int(epochs + 1)):
|
||||||
|
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = None
|
||||||
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
|
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
|
||||||
self.global_step = iterations.numpy()
|
self.global_step = iterations.numpy()
|
||||||
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
|
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
|
||||||
@@ -444,6 +455,10 @@ class TFTrainer:
|
|||||||
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
|
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
|
# Clean the state at the end of training
|
||||||
|
delattr(self, "_past")
|
||||||
|
|
||||||
def _training_steps(self, ds, optimizer):
|
def _training_steps(self, ds, optimizer):
|
||||||
"""
|
"""
|
||||||
Returns a generator over training steps (i.e. parameters update).
|
Returns a generator over training steps (i.e. parameters update).
|
||||||
@@ -518,10 +533,15 @@ class TFTrainer:
|
|||||||
labels: the batched labels.
|
labels: the batched labels.
|
||||||
training: run the model in training mode or not
|
training: run the model in training mode or not
|
||||||
"""
|
"""
|
||||||
|
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
|
||||||
|
features["mems"] = self._past
|
||||||
if isinstance(labels, (dict)):
|
if isinstance(labels, (dict)):
|
||||||
loss, logits = self.model(features, training=training, **labels)[:2]
|
outputs = self.model(features, training=training, **labels)[:2]
|
||||||
else:
|
else:
|
||||||
loss, logits = self.model(features, labels=labels, training=training)[:2]
|
outputs = self.model(features, labels=labels, training=training)[:2]
|
||||||
|
loss, logits = outputs[:2]
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = outputs[self.args.past_index]
|
||||||
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
|
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
|
||||||
|
|
||||||
return loss, logits
|
return loss, logits
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ class TrainingArguments:
|
|||||||
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
||||||
or not.
|
or not.
|
||||||
|
past_index (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
||||||
|
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||||
|
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
|
||||||
|
at the next training step under the keyword argument ``mems``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -203,6 +208,11 @@ class TrainingArguments:
|
|||||||
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
past_index: int = field(
|
||||||
|
default=-1,
|
||||||
|
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_batch_size(self) -> int:
|
def train_batch_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
||||||
or not.
|
or not.
|
||||||
|
past_index (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
||||||
|
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||||
|
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
|
||||||
|
at the next training step under the keyword argument ``mems``.
|
||||||
tpu_name (:obj:`str`, `optional`):
|
tpu_name (:obj:`str`, `optional`):
|
||||||
The name of the TPU the process is running on.
|
The name of the TPU the process is running on.
|
||||||
eval_steps (:obj:`int`, `optional`, defaults to 1000):
|
eval_steps (:obj:`int`, `optional`, defaults to 1000):
|
||||||
|
|||||||
Reference in New Issue
Block a user