Add Watermarking LogitsProcessor and WatermarkDetector (#29676)

* add watermarking processor

* remove the other hashing (context width=1 always)

* make style

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* update watermarking process

* add detector

* update tests to use detector

* fix failing tests

* rename `input_seq`

* make style

* doc for processor

* minor fixes

* docs

* make quality

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add PR suggestions

* let's use lru_cache's default max size (128)

* import processor if torch available

* maybe like this

* lets move the config to torch independet file

* add docs

* tiny docs fix to make the test happy

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* PR suggestions

* add docs

* fix test

* fix docs

* address pr comments

* style

* Revert "style"

This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2.

* correct style

* make doctest green

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2024-05-14 13:31:39 +05:00
committed by GitHub
parent 65ea1904ff
commit 5ad960f1f4
12 changed files with 738 additions and 4 deletions

View File

@@ -53,6 +53,7 @@ if is_torch_available():
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
def test_watermarking_processor(self):
batch_size = 3
vocab_size = 20
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
# raise error if incorrect seeding_scheme is passed
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
# raise error if the greenlist_ratio in not in range (0.0, 1.0)
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
# use fixed id for last token, needed for reprodicibility and tests
input_ids[:, -1] = 10
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())

View File

@@ -75,6 +75,8 @@ if is_torch_available():
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.utils import _speculative_sampling
@@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = model_inputs["input_ids"].shape[-1]
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
args = {
"bias": 2.0,
"context_width": 1,
"seeding_scheme": "selfhash",
"greenlist_ratio": 0.25,
"hashing_key": 15485863,
}
output = model.generate(**model_inputs, do_sample=False, max_length=15)
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
# check that the watermarked text is generating what is should
self.assertListEqual(
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
)
self.assertListEqual(
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
)
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
detection_out = detector(output[:, input_len:], return_dict=True)
# check that the detector is detecting watermarked text
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer