Diverse beam search 2 (#9006)
* diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain <a.jain@sprinklr.com> Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
67ff1c314a
commit
02d0e0355c
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -302,3 +303,30 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
|
||||
)
|
||||
|
||||
def test_hamming_diversity(self):
|
||||
vocab_size = 4
|
||||
num_beams = 2
|
||||
num_beam_groups = 2
|
||||
|
||||
scores = self._get_uniform_logits(num_beams, vocab_size)
|
||||
# batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1
|
||||
# batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1
|
||||
current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long)
|
||||
|
||||
diversity_logits_processor = HammingDiversityLogitsProcessor(
|
||||
diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups
|
||||
)
|
||||
|
||||
processed_scores = diversity_logits_processor(None, scores, current_tokens, 1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user