Fix memory regression in Seq2Seq example (#9713)
* Fix memory regression in Seq2Seq example * Fix test and properly deal with -100 * Easier condition with device safety * Patch for MBartTokenizerFast
This commit is contained in:
@@ -1297,14 +1297,18 @@ class Trainer:
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
if self.label_smoother is not None and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
else:
|
||||
labels = None
|
||||
outputs = model(**inputs)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if self.label_smoother is not None and "labels" in inputs:
|
||||
return self.label_smoother(outputs, inputs["labels"])
|
||||
if labels is not None:
|
||||
return self.label_smoother(outputs, labels)
|
||||
else:
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
@@ -380,17 +380,26 @@ class LabelSmoother:
|
||||
ignore_index: int = -100
|
||||
|
||||
def __call__(self, model_output, labels):
|
||||
model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
|
||||
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
|
||||
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
if labels.dim() == log_probs.dim() - 1:
|
||||
labels = labels.unsqueeze(-1)
|
||||
|
||||
# Look at the ignored index and mask the corresponding log_probs.
|
||||
padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
|
||||
log_probs.masked_fill_(padding_mask, 0.0)
|
||||
padding_mask = labels.eq(self.ignore_index)
|
||||
# In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
|
||||
# will ignore them in any case.
|
||||
labels.clamp_min_(0)
|
||||
nll_loss = log_probs.gather(dim=-1, index=labels)
|
||||
smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
|
||||
|
||||
nll_loss.masked_fill_(padding_mask, 0.0)
|
||||
smoothed_loss.masked_fill_(padding_mask, 0.0)
|
||||
|
||||
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
|
||||
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
|
||||
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
|
||||
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
|
||||
nll_loss = nll_loss.sum() / num_active_elements
|
||||
smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
|
||||
return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
|
||||
|
||||
|
||||
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
|
||||
|
||||
Reference in New Issue
Block a user