Generate: add SequenceBiasLogitsProcessor (#24334)
This commit is contained in:
@@ -46,6 +46,7 @@ if is_torch_available():
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
|
||||
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
|
||||
|
||||
def test_bias_dist_processor(self):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
|
||||
positive_bias = {(1,): 100.0, (4,): 100.0}
|
||||
negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
|
||||
sequence_bias = {**positive_bias, **negative_bias}
|
||||
|
||||
# scores = 0 to facilitate checks
|
||||
scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)
|
||||
|
||||
bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
|
||||
filtered_scores = bias_dist_proc(input_ids, scores.clone())
|
||||
|
||||
# batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
|
||||
# batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
|
||||
self.assertListEqual(
|
||||
filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
|
||||
)
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
|
||||
Reference in New Issue
Block a user