Add SynthID (watermerking by Google DeepMind) (#34350)

* Add SynthIDTextWatermarkLogitsProcessor

* esolving comments.

* Resolving comments.

* esolving commits,

* Improving SynthIDWatermark tests.

* switch to PT version

* detector as pretrained model + style

* update training + style

* rebase

* Update logits_process.py

* Improving SynthIDWatermark tests.

* Shift detector training to wikitext negatives and stabilize with lower learning rate.

* Clean up.

* in for 7B

* cleanup

* upport python 3.8.

* README and final cleanup.

* HF Hub upload and initiaze.

* Update requirements for synthid_text.

* Adding SynthIDTextWatermarkDetector.

* Detector testing.

* Documentation changes.

* Copyrights fix.

* Fix detector api.

* ironing out errors

* ironing out errors

* training checks

* make fixup and make fix-copies

* docstrings and add to docs

* copyright

* BC

* test docstrings

* move import

* protect type hints

* top level imports

* watermarking example

* direct imports

* tpr fpr meaning

* process_kwargs

* SynthIDTextWatermarkingConfig docstring

* assert -> exception

* example updates

* no immutable dict (cant be serialized)

* pack fn

* einsum equivalent

* import order

* fix test on gpu

* add detector example

---------

Co-authored-by: Sumedh Ghaisas <sumedhg@google.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Joao Gante
2024-10-23 21:18:52 +01:00
committed by GitHub
parent e50bf61dec
commit b0f0c61899
15 changed files with 2238 additions and 80 deletions

View File

@@ -16,6 +16,7 @@
import unittest
from typing import List, Union
import numpy as np
from parameterized import parameterized
from transformers import is_torch_available
@@ -48,6 +49,7 @@ if is_torch_available():
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
@parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
"""Test SynthID watermarked distribution bias uniformity over iterations."""
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
batch_size = 100000
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, ngram_len),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
g_values = logits_processor.compute_g_values(ngrams)
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
@parameterized.expand([(10000, 3), (1000, 20)])
def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
"""Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
batch_size = 1000
ngram_len = 5
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
n_minus_1_grams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"] - 1),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngram_keys, _ = logits_processor._compute_keys(
n_minus_1_grams,
torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
)
g_values = logits_processor.sample_g_values(ngram_keys)
# g_values shape should be [batch_size, vocab_size, num_layers]
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
@parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
"""Check if watermarked distribution converges to unwatermarked logits distribution."""
batch_size = 1500
num_keys = 1000
updated_softmaxes = 0
np.random.seed(0)
torch.manual_seed(0)
if logits_type == "uniform":
fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
elif logits_type == "random":
fixed_logits = torch.rand(
(
1,
vocab_size,
),
device=torch_device,
)
fixed_logits = fixed_logits.repeat(batch_size, 1)
else:
raise ValueError(f"Unrecognized logits_type {logits_type}")
for _ in range(num_keys):
watermarking_config = {
"ngram_len": 5,
"keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
"device": torch_device,
}
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"]),
device=torch_device,
)
# Insert ngram-1 into logit_processor state.
for idx in range(watermarking_config["ngram_len"] - 1):
_ = logits_processor(ngrams[:, :idx], fixed_logits)
updated_scores = logits_processor(ngrams, fixed_logits)
updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
is_close = torch.all(
torch.isclose(
torch.tensor(updated_softmaxes, device=torch_device),
torch.nn.Softmax()(fixed_logits[0]), # Take any batch entry, all are same.
atol=1e-3,
rtol=0,
)
)
self.assertTrue(is_close)
@parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
"""Test SynthID watermarking bias matches theoretical value."""
batch_size = 20000
generator = torch.Generator(device=torch_device).manual_seed(0)
np.random.seed(0)
keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
# Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
context = torch.randint(
low=0,
high=10**9,
size=(batch_size, ngram_len - 1),
dtype=torch.int64,
generator=generator,
device=torch_device,
)
context_history_size = 1024
logits_processor = SynthIDTextWatermarkLogitsProcessor(
ngram_len=ngram_len,
keys=keys,
sampling_table_size=2**16,
sampling_table_seed=0,
context_history_size=context_history_size,
device=torch_device,
)
scores = torch.ones(
(batch_size, vocab_size),
dtype=torch.float64,
device=torch_device,
)
# Init state of the logits processor.
logits_processor(context, scores)
# insert context into the state.
for idx in range(1, ngram_len - 1):
_ = logits_processor(context[:, :idx], scores)
updated_scores = logits_processor(context, scores)
probs = torch.nn.functional.softmax(updated_scores, dim=1)
generator = torch.Generator(device=torch_device).manual_seed(0)
next_tokens = torch.multinomial(
probs,
num_samples=1,
generator=generator,
)
ngrams = torch.concat((context, next_tokens), dim=1)
g_values = logits_processor.compute_g_values(ngrams)
mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=vocab_size,
)
is_close = torch.all(
torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
atol=atol,
rtol=0,
)
)
self.assertTrue(is_close)

View File

@@ -84,6 +84,7 @@ if is_torch_available():
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
SynthIDTextWatermarkingConfig,
WatermarkDetector,
WatermarkingConfig,
)
@@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
def test_green_red_watermark_generation(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = model_inputs["input_ids"].shape[-1]
@@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])
"""Check the mean bias inserted by the watermarking algorithm."""
@slow
def test_synthid_text_watermark_generation_mean_expected_bias(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = 5
batch_size = 200
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
mean_g_values_repeats = []
for _ in range(40):
input_ids = torch.zeros(
(batch_size, input_len),
dtype=torch.int64,
device=torch_device,
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids, device=torch_device),
}
output = model.generate(
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
)
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
context_repetition_mask = logits_processor.compute_context_repetition_mask(
input_ids=output[:, input_len:],
).unsqueeze(dim=2)
mean_g_values = torch.masked.mean(
g_values,
mask=context_repetition_mask,
dim=0,
keepdim=True,
dtype=torch.float64,
)
mean_g_values_repeats.append(mean_g_values)
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=model.config.vocab_size,
)
atol = 0.03
is_close = torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64),
atol=atol,
rtol=0,
)
self.assertTrue(torch.all(is_close))
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer