Fix memory regression in Seq2Seq example (#9713)

* Fix memory regression in Seq2Seq example

* Fix test and properly deal with -100

* Easier condition with device safety

* Patch for MBartTokenzierFast
This commit is contained in:
Sylvain Gugger
2021-01-21 12:05:46 -05:00
committed by GitHub
parent a7dabfb3d1
commit 5f80c15ef5
5 changed files with 43 additions and 16 deletions

View File

@@ -1297,14 +1297,18 @@ class Trainer:
Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if self.label_smoother is not None and "labels" in inputs:
return self.label_smoother(outputs, inputs["labels"])
if labels is not None:
return self.label_smoother(outputs, labels)
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

View File

@@ -380,17 +380,26 @@ class LabelSmoother:
ignore_index: int = -100
def __call__(self, model_output, labels):
model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
if labels.dim() == log_probs.dim() - 1:
labels = labels.unsqueeze(-1)
# Look at the ignored index and mask the corresponding log_probs.
padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
log_probs.masked_fill_(padding_mask, 0.0)
padding_mask = labels.eq(self.ignore_index)
# In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
# will ignore them in any case.
labels.clamp_min_(0)
nll_loss = log_probs.gather(dim=-1, index=labels)
smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
nll_loss.masked_fill_(padding_mask, 0.0)
smoothed_loss.masked_fill_(padding_mask, 0.0)
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
nll_loss = nll_loss.sum() / num_active_elements
smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):