Add Watermarking LogitsProcessor and WatermarkDetector (#29676)
* add watermarking processor * remove the other hashing (context width=1 always) * make style * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * update watermarking process * add detector * update tests to use detector * fix failing tests * rename `input_seq` * make style * doc for processor * minor fixes * docs * make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add PR suggestions * let's use lru_cache's default max size (128) * import processor if torch available * maybe like this * lets move the config to torch independet file * add docs * tiny docs fix to make the test happy * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * PR suggestions * add docs * fix test * fix docs * address pr comments * style * Revert "style" This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2. * correct style * make doctest green --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
65ea1904ff
commit
5ad960f1f4
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||
|
||||
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
def test_watermarking_processor(self):
|
||||
batch_size = 3
|
||||
vocab_size = 20
|
||||
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
# raise error if incorrect seeding_scheme is passed
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
|
||||
|
||||
# raise error if the greenlist_ratio in not in range (0.0, 1.0)
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
|
||||
|
||||
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
|
||||
|
||||
# use fixed id for last token, needed for reprodicibility and tests
|
||||
input_ids[:, -1] = 10
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||
|
||||
Reference in New Issue
Block a user