Add SynthID (watermerking by Google DeepMind) (#34350)

* Add SynthIDTextWatermarkLogitsProcessor

* esolving comments.

* Resolving comments.

* esolving commits,

* Improving SynthIDWatermark tests.

* switch to PT version

* detector as pretrained model + style

* update training + style

* rebase

* Update logits_process.py

* Improving SynthIDWatermark tests.

* Shift detector training to wikitext negatives and stabilize with lower learning rate.

* Clean up.

* in for 7B

* cleanup

* upport python 3.8.

* README and final cleanup.

* HF Hub upload and initiaze.

* Update requirements for synthid_text.

* Adding SynthIDTextWatermarkDetector.

* Detector testing.

* Documentation changes.

* Copyrights fix.

* Fix detector api.

* ironing out errors

* ironing out errors

* training checks

* make fixup and make fix-copies

* docstrings and add to docs

* copyright

* BC

* test docstrings

* move import

* protect type hints

* top level imports

* watermarking example

* direct imports

* tpr fpr meaning

* process_kwargs

* SynthIDTextWatermarkingConfig docstring

* assert -> exception

* example updates

* no immutable dict (cant be serialized)

* pack fn

* einsum equivalent

* import order

* fix test on gpu

* add detector example

---------

Co-authored-by: Sumedh Ghaisas <sumedhg@google.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Joao Gante
2024-10-23 21:18:52 +01:00
committed by GitHub
parent e50bf61dec
commit b0f0c61899
15 changed files with 2238 additions and 80 deletions

View File

@@ -84,6 +84,7 @@ if is_torch_available():
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
SynthIDTextWatermarkingConfig,
WatermarkDetector,
WatermarkingConfig,
)
@@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
def test_green_red_watermark_generation(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = model_inputs["input_ids"].shape[-1]
@@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])
"""Check the mean bias inserted by the watermarking algorithm."""
@slow
def test_synthid_text_watermark_generation_mean_expected_bias(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = 5
batch_size = 200
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
mean_g_values_repeats = []
for _ in range(40):
input_ids = torch.zeros(
(batch_size, input_len),
dtype=torch.int64,
device=torch_device,
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids, device=torch_device),
}
output = model.generate(
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
)
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
context_repetition_mask = logits_processor.compute_context_repetition_mask(
input_ids=output[:, input_len:],
).unsqueeze(dim=2)
mean_g_values = torch.masked.mean(
g_values,
mask=context_repetition_mask,
dim=0,
keepdim=True,
dtype=torch.float64,
)
mean_g_values_repeats.append(mean_g_values)
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=model.config.vocab_size,
)
atol = 0.03
is_close = torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64),
atol=atol,
rtol=0,
)
self.assertTrue(torch.all(is_close))
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer