disable_ngram_loss fix for prophetnet (#8554)

* `disable_ngram_loss` fix for prophetnet
* add changes documentation
* fix _compute_loss to use mean reduction and -100 to masked tokens & remove unnecessary arguments
* mean label smoothing loss
* small refactor
* fix test

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>

@@ -417,7 +417,7 @@ class ProphetNetModelTester:
             decoder_attention_mask=decoder_attention_mask,
             labels=lm_labels,
         )
-        self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(128.2925, device=torch_device), atol=1e-3))
+        self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5819, device=torch_device), atol=1e-3))
 
         expected_logit_slice = torch.tensor(
             [-0.1565, 0.0418, 0.1207, 0.0030, 0.0665, 0.0467, 0.0412], device=torch_device
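
The expected loss in the test drops from 128.2925 to 4.5819, a factor of about 28, which would be consistent with the loss now being averaged over the counted target positions instead of summed. The sketch below illustrates that difference in plain Python, without PyTorch. It is not the ProphetNet `_compute_loss` code: `nll_loss`, `IGNORE_INDEX`, and the toy inputs are hypothetical names chosen for illustration, and returning 0.0 when every label is masked is a choice made for this sketch only.

```python
import math

IGNORE_INDEX = -100  # label value marking positions excluded from the loss


def nll_loss(log_probs, labels, reduction="mean"):
    """Negative log-likelihood over positions, skipping IGNORE_INDEX labels.

    log_probs: one list of per-class log-probabilities per position
    labels:    one class index per position (or IGNORE_INDEX)
    """
    terms = [
        -row[label]
        for row, label in zip(log_probs, labels)
        if label != IGNORE_INDEX
    ]
    if reduction == "sum":
        return sum(terms)
    if reduction == "mean":
        # Average only over positions that were actually counted.
        return sum(terms) / len(terms) if terms else 0.0
    raise ValueError(f"unknown reduction: {reduction!r}")


# Toy input: 5 positions, uniform distribution over 4 classes,
# with the last two positions masked out.
uniform = [math.log(0.25)] * 4
log_probs = [uniform] * 5
labels = [0, 3, 1, IGNORE_INDEX, IGNORE_INDEX]

loss_sum = nll_loss(log_probs, labels, reduction="sum")    # grows with the number of counted positions
loss_mean = nll_loss(log_probs, labels, reduction="mean")  # independent of the number of counted positions
```

With a summed loss the value scales with how many target tokens are counted, so it is not comparable across batch sizes or sequence lengths; the mean-reduced value stays on a per-token scale.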