[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)

* push

* finish

* finish

* make fix copies

* change name
This commit is contained in:
Patrick von Platen
2021-03-23 01:00:05 +03:00
committed by GitHub
parent 9f8fa4e973
commit 77bf3fe787
9 changed files with 156 additions and 36 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_remove_nan_inf_logits_processor(self):
scores = torch.tensor(
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
)
input_ids = ids_tensor((2, 4), vocab_size=20)
logits_processor = InfNanRemoveLogitsProcessor()
scores = logits_processor(input_ids, scores)
self.assertTrue(
torch.allclose(
scores,
torch.tensor(
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
device=torch_device,
),
atol=1e-6,
)
)