fix typos in the tests directory (#36717)
This commit is contained in:
@@ -751,7 +751,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
processed_scores = logits_processor(input_ids, scores)
|
||||
self.assertTrue(torch.isneginf(processed_scores[:, bos_token_id + 1 :]).all())
|
||||
# score for bos_token_id shold be zero
|
||||
# score for bos_token_id should be zero
|
||||
self.assertListEqual(processed_scores[:, bos_token_id].tolist(), 4 * [0])
|
||||
|
||||
# processor should not change logits in-place
|
||||
@@ -972,7 +972,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
|
||||
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
|
||||
|
||||
# use fixed id for last token, needed for reprodicibility and tests
|
||||
# use fixed id for last token, needed for reproducibility and tests
|
||||
input_ids[:, -1] = 10
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
|
||||
Reference in New Issue
Block a user