Add Watermarking LogitsProcessor and WatermarkDetector (#29676)

* add watermarking processor

* remove the other hashing (context width=1 always)

* make style

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* update watermarking process

* add detector

* update tests to use detector

* fix failing tests

* rename `input_seq`

* make style

* doc for processor

* minor fixes

* docs

* make quality

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add PR suggestions

* let's use lru_cache's default max size (128)

* import processor if torch available

* maybe like this

* lets move the config to torch independet file

* add docs

* tiny docs fix to make the test happy

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* PR suggestions

* add docs

* fix test

* fix docs

* address pr comments

* style

* Revert "style"

This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2.

* correct style

* make doctest green

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2024-05-14 13:31:39 +05:00
committed by GitHub
parent 65ea1904ff
commit 5ad960f1f4
12 changed files with 738 additions and 4 deletions

View File

@@ -53,6 +53,7 @@ if is_torch_available():
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
def test_watermarking_processor(self):
batch_size = 3
vocab_size = 20
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
# raise error if incorrect seeding_scheme is passed
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
# raise error if the greenlist_ratio in not in range (0.0, 1.0)
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
# use fixed id for last token, needed for reprodicibility and tests
input_ids[:, -1] = 10
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())