Fix memory regression in Seq2Seq example (#9713)
* Fix memory regression in Seq2Seq example * Fix test and properly deal with -100 * Easier condition with device safety * Patch for MBartTokenizerFast
This commit is contained in:
@@ -380,17 +380,26 @@ class LabelSmoother:
|
||||
ignore_index: int = -100
|
||||
|
||||
def __call__(self, model_output, labels):
|
||||
model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
|
||||
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
|
||||
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
if labels.dim() == log_probs.dim() - 1:
|
||||
labels = labels.unsqueeze(-1)
|
||||
|
||||
# Look at the ignored index and mask the corresponding log_probs.
|
||||
padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
|
||||
log_probs.masked_fill_(padding_mask, 0.0)
|
||||
padding_mask = labels.eq(self.ignore_index)
|
||||
# In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
|
||||
# will ignore them in any case.
|
||||
labels.clamp_min_(0)
|
||||
nll_loss = log_probs.gather(dim=-1, index=labels)
|
||||
smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
|
||||
|
||||
nll_loss.masked_fill_(padding_mask, 0.0)
|
||||
smoothed_loss.masked_fill_(padding_mask, 0.0)
|
||||
|
||||
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
|
||||
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
|
||||
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
|
||||
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
|
||||
nll_loss = nll_loss.sum() / num_active_elements
|
||||
smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
|
||||
return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
|
||||
|
||||
|
||||
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
|
||||
|
||||
Reference in New Issue
Block a user