Diverse beam search 2 (#9006)

* diverse beam search

* bug fixes

* bug fixes

* bug fix

* separate out diverse_beam_search function

* separate out diverse_beam_search function

* bug fix

* improve code quality

* bug fix

* bug fix

* separate out diverse beam search scorer

* code format

* code format

* code format

* code format

* add test

* code format

* documentation changes

* code quality

* add slow integration tests

* more general name

* refactor into logits processor

* add test

* avoid too much copy paste

* refactor

* add to docs

* fix-copies

* bug fix

* Revert "bug fix"

This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4.

* improve comment

* implement sylvains feedback

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-12-09 15:00:37 +01:00
committed by GitHub
parent 67ff1c314a
commit 02d0e0355c
10 changed files with 590 additions and 20 deletions

View File

@@ -27,6 +27,7 @@ if is_torch_available():
import torch.nn.functional as F
from transformers.generation_logits_process import (
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -302,3 +303,30 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)
def test_hamming_diversity(self):
vocab_size = 4
num_beams = 2
num_beam_groups = 2
scores = self._get_uniform_logits(num_beams, vocab_size)
# batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1
# batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1
current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long)
diversity_logits_processor = HammingDiversityLogitsProcessor(
diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups
)
processed_scores = diversity_logits_processor(None, scores, current_tokens, 1)
self.assertTrue(
torch.allclose(
processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3
)
)
self.assertTrue(
torch.allclose(
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
)
)