Add SynthID (watermerking by Google DeepMind) (#34350)
* Add SynthIDTextWatermarkLogitsProcessor * esolving comments. * Resolving comments. * esolving commits, * Improving SynthIDWatermark tests. * switch to PT version * detector as pretrained model + style * update training + style * rebase * Update logits_process.py * Improving SynthIDWatermark tests. * Shift detector training to wikitext negatives and stabilize with lower learning rate. * Clean up. * in for 7B * cleanup * upport python 3.8. * README and final cleanup. * HF Hub upload and initiaze. * Update requirements for synthid_text. * Adding SynthIDTextWatermarkDetector. * Detector testing. * Documentation changes. * Copyrights fix. * Fix detector api. * ironing out errors * ironing out errors * training checks * make fixup and make fix-copies * docstrings and add to docs * copyright * BC * test docstrings * move import * protect type hints * top level imports * watermarking example * direct imports * tpr fpr meaning * process_kwargs * SynthIDTextWatermarkingConfig docstring * assert -> exception * example updates * no immutable dict (cant be serialized) * pack fn * einsum equivalent * import order * fix test on gpu * add detector example --------- Co-authored-by: Sumedh Ghaisas <sumedhg@google.com> Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com> Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import unittest
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import is_torch_available
|
||||
@@ -48,6 +49,7 @@ if is_torch_available():
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
@@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||
|
||||
@parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
|
||||
def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
|
||||
"""Test SynthID watermarked distribution bias uniformity over iterations."""
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
watermarking_config = {
|
||||
"ngram_len": ngram_len,
|
||||
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 512,
|
||||
"device": torch_device,
|
||||
}
|
||||
batch_size = 100000
|
||||
ngrams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ngram_len),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
g_values = logits_processor.compute_g_values(ngrams)
|
||||
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
|
||||
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
|
||||
|
||||
@parameterized.expand([(10000, 3), (1000, 20)])
|
||||
def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
|
||||
"""Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
|
||||
batch_size = 1000
|
||||
ngram_len = 5
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
watermarking_config = {
|
||||
"ngram_len": ngram_len,
|
||||
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 512,
|
||||
"device": torch_device,
|
||||
}
|
||||
n_minus_1_grams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, watermarking_config["ngram_len"] - 1),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
ngram_keys, _ = logits_processor._compute_keys(
|
||||
n_minus_1_grams,
|
||||
torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
|
||||
)
|
||||
|
||||
g_values = logits_processor.sample_g_values(ngram_keys)
|
||||
# g_values shape should be [batch_size, vocab_size, num_layers]
|
||||
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
|
||||
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
|
||||
|
||||
@parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
|
||||
def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
|
||||
"""Check if watermarked distribution converges to unwatermarked logits distribution."""
|
||||
batch_size = 1500
|
||||
num_keys = 1000
|
||||
|
||||
updated_softmaxes = 0
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
if logits_type == "uniform":
|
||||
fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
|
||||
elif logits_type == "random":
|
||||
fixed_logits = torch.rand(
|
||||
(
|
||||
1,
|
||||
vocab_size,
|
||||
),
|
||||
device=torch_device,
|
||||
)
|
||||
fixed_logits = fixed_logits.repeat(batch_size, 1)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized logits_type {logits_type}")
|
||||
for _ in range(num_keys):
|
||||
watermarking_config = {
|
||||
"ngram_len": 5,
|
||||
"keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 1024,
|
||||
"device": torch_device,
|
||||
}
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
|
||||
ngrams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, watermarking_config["ngram_len"]),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# Insert ngram-1 into logit_processor state.
|
||||
for idx in range(watermarking_config["ngram_len"] - 1):
|
||||
_ = logits_processor(ngrams[:, :idx], fixed_logits)
|
||||
|
||||
updated_scores = logits_processor(ngrams, fixed_logits)
|
||||
updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
|
||||
|
||||
updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
|
||||
is_close = torch.all(
|
||||
torch.isclose(
|
||||
torch.tensor(updated_softmaxes, device=torch_device),
|
||||
torch.nn.Softmax()(fixed_logits[0]), # Take any batch entry, all are same.
|
||||
atol=1e-3,
|
||||
rtol=0,
|
||||
)
|
||||
)
|
||||
self.assertTrue(is_close)
|
||||
|
||||
@parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
|
||||
def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
|
||||
"""Test SynthID watermarking bias matches theoretical value."""
|
||||
batch_size = 20000
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
|
||||
# Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
|
||||
context = torch.randint(
|
||||
low=0,
|
||||
high=10**9,
|
||||
size=(batch_size, ngram_len - 1),
|
||||
dtype=torch.int64,
|
||||
generator=generator,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
context_history_size = 1024
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(
|
||||
ngram_len=ngram_len,
|
||||
keys=keys,
|
||||
sampling_table_size=2**16,
|
||||
sampling_table_seed=0,
|
||||
context_history_size=context_history_size,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
scores = torch.ones(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float64,
|
||||
device=torch_device,
|
||||
)
|
||||
# Init state of the logits processor.
|
||||
logits_processor(context, scores)
|
||||
# insert context into the state.
|
||||
for idx in range(1, ngram_len - 1):
|
||||
_ = logits_processor(context[:, :idx], scores)
|
||||
|
||||
updated_scores = logits_processor(context, scores)
|
||||
|
||||
probs = torch.nn.functional.softmax(updated_scores, dim=1)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
next_tokens = torch.multinomial(
|
||||
probs,
|
||||
num_samples=1,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
ngrams = torch.concat((context, next_tokens), dim=1)
|
||||
g_values = logits_processor.compute_g_values(ngrams)
|
||||
mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
|
||||
|
||||
expected_mean_g_value = logits_processor.expected_mean_g_value(
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
is_close = torch.all(
|
||||
torch.isclose(
|
||||
mean_g_values,
|
||||
torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
)
|
||||
)
|
||||
self.assertTrue(is_close)
|
||||
|
||||
Reference in New Issue
Block a user