[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)
* push * finish * finish * make fix copies * change name
This commit is contained in:
committed by
GitHub
parent
9f8fa4e973
commit
77bf3fe787
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores)
|
||||
self.assertFalse(torch.isinf(scores).any())
|
||||
|
||||
def test_remove_nan_inf_logits_processor(self):
|
||||
scores = torch.tensor(
|
||||
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
|
||||
)
|
||||
input_ids = ids_tensor((2, 4), vocab_size=20)
|
||||
|
||||
logits_processor = InfNanRemoveLogitsProcessor()
|
||||
|
||||
scores = logits_processor(input_ids, scores)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
scores,
|
||||
torch.tensor(
|
||||
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
||||
device=torch_device,
|
||||
),
|
||||
atol=1e-6,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user