Add SynthID (watermerking by Google DeepMind) (#34350)
* Add SynthIDTextWatermarkLogitsProcessor * esolving comments. * Resolving comments. * esolving commits, * Improving SynthIDWatermark tests. * switch to PT version * detector as pretrained model + style * update training + style * rebase * Update logits_process.py * Improving SynthIDWatermark tests. * Shift detector training to wikitext negatives and stabilize with lower learning rate. * Clean up. * in for 7B * cleanup * upport python 3.8. * README and final cleanup. * HF Hub upload and initiaze. * Update requirements for synthid_text. * Adding SynthIDTextWatermarkDetector. * Detector testing. * Documentation changes. * Copyrights fix. * Fix detector api. * ironing out errors * ironing out errors * training checks * make fixup and make fix-copies * docstrings and add to docs * copyright * BC * test docstrings * move import * protect type hints * top level imports * watermarking example * direct imports * tpr fpr meaning * process_kwargs * SynthIDTextWatermarkingConfig docstring * assert -> exception * example updates * no immutable dict (cant be serialized) * pack fn * einsum equivalent * import order * fix test on gpu * add detector example --------- Co-authored-by: Sumedh Ghaisas <sumedhg@google.com> Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com> Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
@@ -84,6 +84,7 @@ if is_torch_available():
|
||||
SampleEncoderDecoderOutput,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
@@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_watermark_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
def test_green_red_watermark_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = model_inputs["input_ids"].shape[-1]
|
||||
@@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||
|
||||
"""Check the mean bias inserted by the watermarking algorithm."""
|
||||
|
||||
@slow
|
||||
def test_synthid_text_watermark_generation_mean_expected_bias(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = 5
|
||||
batch_size = 200
|
||||
|
||||
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
|
||||
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
|
||||
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
|
||||
mean_g_values_repeats = []
|
||||
for _ in range(40):
|
||||
input_ids = torch.zeros(
|
||||
(batch_size, input_len),
|
||||
dtype=torch.int64,
|
||||
device=torch_device,
|
||||
)
|
||||
model_inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": torch.ones_like(input_ids, device=torch_device),
|
||||
}
|
||||
output = model.generate(
|
||||
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
|
||||
)
|
||||
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
|
||||
context_repetition_mask = logits_processor.compute_context_repetition_mask(
|
||||
input_ids=output[:, input_len:],
|
||||
).unsqueeze(dim=2)
|
||||
|
||||
mean_g_values = torch.masked.mean(
|
||||
g_values,
|
||||
mask=context_repetition_mask,
|
||||
dim=0,
|
||||
keepdim=True,
|
||||
dtype=torch.float64,
|
||||
)
|
||||
mean_g_values_repeats.append(mean_g_values)
|
||||
|
||||
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
|
||||
expected_mean_g_value = logits_processor.expected_mean_g_value(
|
||||
vocab_size=model.config.vocab_size,
|
||||
)
|
||||
atol = 0.03
|
||||
is_close = torch.isclose(
|
||||
mean_g_values,
|
||||
torch.tensor(expected_mean_g_value, dtype=torch.float64),
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
)
|
||||
self.assertTrue(torch.all(is_close))
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|
||||
Reference in New Issue
Block a user