Add SynthID (watermerking by Google DeepMind) (#34350)

* Add SynthIDTextWatermarkLogitsProcessor

* esolving comments.

* Resolving comments.

* esolving commits,

* Improving SynthIDWatermark tests.

* switch to PT version

* detector as pretrained model + style

* update training + style

* rebase

* Update logits_process.py

* Improving SynthIDWatermark tests.

* Shift detector training to wikitext negatives and stabilize with lower learning rate.

* Clean up.

* in for 7B

* cleanup

* upport python 3.8.

* README and final cleanup.

* HF Hub upload and initiaze.

* Update requirements for synthid_text.

* Adding SynthIDTextWatermarkDetector.

* Detector testing.

* Documentation changes.

* Copyrights fix.

* Fix detector api.

* ironing out errors

* ironing out errors

* training checks

* make fixup and make fix-copies

* docstrings and add to docs

* copyright

* BC

* test docstrings

* move import

* protect type hints

* top level imports

* watermarking example

* direct imports

* tpr fpr meaning

* process_kwargs

* SynthIDTextWatermarkingConfig docstring

* assert -> exception

* example updates

* no immutable dict (cant be serialized)

* pack fn

* einsum equivalent

* import order

* fix test on gpu

* add detector example

---------

Co-authored-by: Sumedh Ghaisas <sumedhg@google.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Joao Gante
2024-10-23 21:18:52 +01:00
committed by GitHub
parent e50bf61dec
commit b0f0c61899
15 changed files with 2238 additions and 80 deletions

View File

@@ -16,6 +16,7 @@
import unittest
from typing import List, Union
import numpy as np
from parameterized import parameterized
from transformers import is_torch_available
@@ -48,6 +49,7 @@ if is_torch_available():
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
@parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
"""Test SynthID watermarked distribution bias uniformity over iterations."""
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
batch_size = 100000
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, ngram_len),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
g_values = logits_processor.compute_g_values(ngrams)
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
@parameterized.expand([(10000, 3), (1000, 20)])
def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
"""Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
batch_size = 1000
ngram_len = 5
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
n_minus_1_grams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"] - 1),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngram_keys, _ = logits_processor._compute_keys(
n_minus_1_grams,
torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
)
g_values = logits_processor.sample_g_values(ngram_keys)
# g_values shape should be [batch_size, vocab_size, num_layers]
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
@parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
"""Check if watermarked distribution converges to unwatermarked logits distribution."""
batch_size = 1500
num_keys = 1000
updated_softmaxes = 0
np.random.seed(0)
torch.manual_seed(0)
if logits_type == "uniform":
fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
elif logits_type == "random":
fixed_logits = torch.rand(
(
1,
vocab_size,
),
device=torch_device,
)
fixed_logits = fixed_logits.repeat(batch_size, 1)
else:
raise ValueError(f"Unrecognized logits_type {logits_type}")
for _ in range(num_keys):
watermarking_config = {
"ngram_len": 5,
"keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
"device": torch_device,
}
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"]),
device=torch_device,
)
# Insert ngram-1 into logit_processor state.
for idx in range(watermarking_config["ngram_len"] - 1):
_ = logits_processor(ngrams[:, :idx], fixed_logits)
updated_scores = logits_processor(ngrams, fixed_logits)
updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
is_close = torch.all(
torch.isclose(
torch.tensor(updated_softmaxes, device=torch_device),
torch.nn.Softmax()(fixed_logits[0]), # Take any batch entry, all are same.
atol=1e-3,
rtol=0,
)
)
self.assertTrue(is_close)
@parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
"""Test SynthID watermarking bias matches theoretical value."""
batch_size = 20000
generator = torch.Generator(device=torch_device).manual_seed(0)
np.random.seed(0)
keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
# Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
context = torch.randint(
low=0,
high=10**9,
size=(batch_size, ngram_len - 1),
dtype=torch.int64,
generator=generator,
device=torch_device,
)
context_history_size = 1024
logits_processor = SynthIDTextWatermarkLogitsProcessor(
ngram_len=ngram_len,
keys=keys,
sampling_table_size=2**16,
sampling_table_seed=0,
context_history_size=context_history_size,
device=torch_device,
)
scores = torch.ones(
(batch_size, vocab_size),
dtype=torch.float64,
device=torch_device,
)
# Init state of the logits processor.
logits_processor(context, scores)
# insert context into the state.
for idx in range(1, ngram_len - 1):
_ = logits_processor(context[:, :idx], scores)
updated_scores = logits_processor(context, scores)
probs = torch.nn.functional.softmax(updated_scores, dim=1)
generator = torch.Generator(device=torch_device).manual_seed(0)
next_tokens = torch.multinomial(
probs,
num_samples=1,
generator=generator,
)
ngrams = torch.concat((context, next_tokens), dim=1)
g_values = logits_processor.compute_g_values(ngrams)
mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=vocab_size,
)
is_close = torch.all(
torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
atol=atol,
rtol=0,
)
)
self.assertTrue(is_close)