[s2s] fix label_smoothed_nll_loss (#6344)
This commit is contained in:
@@ -29,17 +29,15 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
|||||||
pad_mask = target.eq(ignore_index)
|
pad_mask = target.eq(ignore_index)
|
||||||
nll_loss.masked_fill_(pad_mask, 0.0)
|
nll_loss.masked_fill_(pad_mask, 0.0)
|
||||||
smooth_loss.masked_fill_(pad_mask, 0.0)
|
smooth_loss.masked_fill_(pad_mask, 0.0)
|
||||||
bs = pad_mask.long().sum()
|
|
||||||
else:
|
else:
|
||||||
nll_loss = nll_loss.squeeze(-1)
|
nll_loss = nll_loss.squeeze(-1)
|
||||||
smooth_loss = smooth_loss.squeeze(-1)
|
smooth_loss = smooth_loss.squeeze(-1)
|
||||||
bs = lprobs.shape[0]
|
|
||||||
|
|
||||||
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
|
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
|
||||||
smooth_loss = smooth_loss.sum()
|
smooth_loss = smooth_loss.sum()
|
||||||
eps_i = epsilon / lprobs.size(-1)
|
eps_i = epsilon / lprobs.size(-1)
|
||||||
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
|
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
|
||||||
return loss / bs, nll_loss / bs
|
return loss, nll_loss
|
||||||
|
|
||||||
|
|
||||||
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||||
|
|||||||
Reference in New Issue
Block a user