Generate: add SequenceBiasLogitsProcessor (#24334)

This commit is contained in:
Joao Gante
2023-06-21 11:14:41 +01:00
committed by GitHub
parent 45f71d793d
commit 5f0801d174
8 changed files with 241 additions and 123 deletions

View File

@@ -46,6 +46,7 @@ if is_torch_available():
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
def test_bias_dist_processor(self):
vocab_size = 5
batch_size = 2
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
positive_bias = {(1,): 100.0, (4,): 100.0}
negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
sequence_bias = {**positive_bias, **negative_bias}
# scores = 0 to facilitate checks
scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)
bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
filtered_scores = bias_dist_proc(input_ids, scores.clone())
# batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
# batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
self.assertListEqual(
filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
)
def test_processor_list(self):
batch_size = 4
sequence_length = 10