Add SynthID (watermerking by Google DeepMind) (#34350)

* Add SynthIDTextWatermarkLogitsProcessor

* esolving comments.

* Resolving comments.

* esolving commits,

* Improving SynthIDWatermark tests.

* switch to PT version

* detector as pretrained model + style

* update training + style

* rebase

* Update logits_process.py

* Improving SynthIDWatermark tests.

* Shift detector training to wikitext negatives and stabilize with lower learning rate.

* Clean up.

* in for 7B

* cleanup

* upport python 3.8.

* README and final cleanup.

* HF Hub upload and initiaze.

* Update requirements for synthid_text.

* Adding SynthIDTextWatermarkDetector.

* Detector testing.

* Documentation changes.

* Copyrights fix.

* Fix detector api.

* ironing out errors

* ironing out errors

* training checks

* make fixup and make fix-copies

* docstrings and add to docs

* copyright

* BC

* test docstrings

* move import

* protect type hints

* top level imports

* watermarking example

* direct imports

* tpr fpr meaning

* process_kwargs

* SynthIDTextWatermarkingConfig docstring

* assert -> exception

* example updates

* no immutable dict (cant be serialized)

* pack fn

* einsum equivalent

* import order

* fix test on gpu

* add detector example

---------

Co-authored-by: Sumedh Ghaisas <sumedhg@google.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Joao Gante
2024-10-23 21:18:52 +01:00
committed by GitHub
parent e50bf61dec
commit b0f0c61899
15 changed files with 2238 additions and 80 deletions

View File

@@ -0,0 +1,34 @@
# SynthID Text
This project showcases the use of SynthIDText for watermarking LLMs. The code shown in this repo also
demostrates the training of the detector for detecting such watermarked text. This detector can be uploaded onto
a private HF hub repo (private for security reasons) and can be initialized again through pretrained model loading also shown in this script.
See our blog post: https://huggingface.co/blog/synthid-text
## Python version
User would need python 3.9 to run this example.
## Installation and running
Once you install transformers you would need to install requirements for this project through requirements.txt provided in this folder.
```
pip install -r requirements.txt
```
## To run the detector training
```
python detector_training.py --model_name=google/gemma-7b-it
```
Check the script for more parameters are are tunable and check out paper at link
https://www.nature.com/articles/s41586-024-08025-4 for more information on these parameters.
## Caveat
Make sure to run the training of the detector and the detection on the same hardware
CPU, GPU or TPU to get consistent results (we use detecterministic randomness which is hardware dependent).

View File

@@ -0,0 +1,502 @@
# coding=utf-8
# Copyright 2024 Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import dataclasses
import enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BayesianDetectorConfig,
BayesianDetectorModel,
SynthIDTextWatermarkDetector,
SynthIDTextWatermarkingConfig,
SynthIDTextWatermarkLogitsProcessor,
)
from utils import (
get_tokenized_uwm_outputs,
get_tokenized_wm_outputs,
process_raw_model_outputs,
update_fn_if_fpr_tpr,
upload_model_to_hf,
)
@enum.unique
class ValidationMetric(enum.Enum):
"""Direction along the z-axis."""
TPR_AT_FPR = "tpr_at_fpr"
CROSS_ENTROPY = "cross_entropy"
@dataclasses.dataclass
class TrainingArguments:
"""Training arguments pertaining to the training loop itself."""
eval_metric: Optional[str] = dataclasses.field(
default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
)
def train_detector(
detector: torch.nn.Module,
g_values: torch.Tensor,
mask: torch.Tensor,
watermarked: torch.Tensor,
epochs: int = 250,
learning_rate: float = 1e-3,
minibatch_size: int = 64,
seed: int = 0,
l2_weight: float = 0.0,
shuffle: bool = True,
g_values_val: Optional[torch.Tensor] = None,
mask_val: Optional[torch.Tensor] = None,
watermarked_val: Optional[torch.Tensor] = None,
verbose: bool = False,
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
) -> Tuple[Dict[str, Any], float]:
"""Trains a Bayesian detector model.
Args:
g_values: g-values of shape [num_train, seq_len, watermarking_depth].
mask: A binary array shape [num_train, seq_len] indicating which g-values
should be used. g-values with mask value 0 are discarded.
watermarked: A binary array of shape [num_train] indicating whether the
example is watermarked (0: unwatermarked, 1: watermarked).
epochs: Number of epochs to train for.
learning_rate: Learning rate for optimizer.
minibatch_size: Minibatch size for training. Note that a minibatch
requires ~ 32 * minibatch_size * seq_len * watermarked_depth *
watermarked_depth bits of memory.
seed: Seed for parameter initialization.
l2_weight: Weight to apply to L2 regularization for delta parameters.
shuffle: Whether to shuffle before training.
g_values_val: Validation g-values of shape [num_val, seq_len,
watermarking_depth].
mask_val: Validation mask of shape [num_val, seq_len].
watermarked_val: Validation watermark labels of shape [num_val].
verbose: Boolean indicating verbosity of training. If true, the loss will
be printed. Defaulted to False.
use_tpr_fpr_for_val: Whether to use TPR@FPR=1% as metric for validation.
If false, use cross entropy loss.
Returns:
Tuple of
training_history: Training history keyed by epoch number where the
values are
dictionaries containing the loss, validation loss, and model
parameters,
keyed by
'loss', 'val_loss', and 'params', respectively.
min_val_loss: Minimum validation loss achieved during training.
"""
# Set the random seed for reproducibility
torch.manual_seed(seed)
# Shuffle the data if required
if shuffle:
indices = torch.randperm(len(g_values))
g_values = g_values[indices]
mask = mask[indices]
watermarked = watermarked[indices]
# Initialize optimizer
optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate)
history = {}
min_val_loss = float("inf")
for epoch in range(epochs):
losses = []
detector.train()
num_batches = len(g_values) // minibatch_size
for i in range(0, len(g_values), minibatch_size):
end = i + minibatch_size
if end > len(g_values):
break
loss_batch_weight = l2_weight / num_batches
optimizer.zero_grad()
loss = detector(
g_values=g_values[i:end],
mask=mask[i:end],
labels=watermarked[i:end],
loss_batch_weight=loss_batch_weight,
)[1]
loss.backward()
optimizer.step()
losses.append(loss.item())
train_loss = sum(losses) / len(losses)
val_losses = []
if g_values_val is not None:
detector.eval()
if validation_metric == ValidationMetric.TPR_AT_FPR:
val_loss = update_fn_if_fpr_tpr(
detector,
g_values_val,
mask_val,
watermarked_val,
minibatch_size=minibatch_size,
)
else:
for i in range(0, len(g_values_val), minibatch_size):
end = i + minibatch_size
if end > len(g_values_val):
break
with torch.no_grad():
v_loss = detector(
g_values=g_values_val[i:end],
mask=mask_val[i:end],
labels=watermarked_val[i:end],
loss_batch_weight=0,
)[1]
val_losses.append(v_loss.item())
val_loss = sum(val_losses) / len(val_losses)
# Store training history
history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss}
if verbose:
if val_loss is not None:
print(f"Epoch {epoch}: loss {loss} (train), {val_loss} (val)")
else:
print(f"Epoch {epoch}: loss {loss} (train)")
if val_loss is not None and val_loss < min_val_loss:
min_val_loss = val_loss
best_val_epoch = epoch
if verbose:
print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}")
return history, min_val_loss
def train_best_detector(
tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray],
tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray],
logits_processor: SynthIDTextWatermarkLogitsProcessor,
tokenizer: Any,
torch_device: torch.device,
test_size: float = 0.3,
pos_truncation_length: Optional[int] = 200,
neg_truncation_length: Optional[int] = 100,
max_padded_length: int = 2300,
n_epochs: int = 50,
learning_rate: float = 2.1e-2,
l2_weights: np.ndarray = np.logspace(-3, -2, num=4),
verbose: bool = False,
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
):
"""Train and return the best detector given range of hyperparameters.
In practice, we have found that tuning pos_truncation_length,
neg_truncation_length, n_epochs, learning_rate and l2_weights can help
improve the performance of the detector. We reccommend tuning these
parameters for your data.
"""
l2_weights = list(l2_weights)
(
train_g_values,
train_masks,
train_labels,
cv_g_values,
cv_masks,
cv_labels,
) = process_raw_model_outputs(
logits_processor,
tokenizer,
pos_truncation_length,
neg_truncation_length,
max_padded_length,
tokenized_wm_outputs,
test_size,
tokenized_uwm_outputs,
torch_device,
)
best_detector = None
lowest_loss = float("inf")
val_losses = []
for l2_weight in l2_weights:
config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys))
detector = BayesianDetectorModel(config).to(torch_device)
_, min_val_loss = train_detector(
detector=detector,
g_values=train_g_values,
mask=train_masks,
watermarked=train_labels,
g_values_val=cv_g_values,
mask_val=cv_masks,
watermarked_val=cv_labels,
learning_rate=learning_rate,
l2_weight=l2_weight,
epochs=n_epochs,
verbose=verbose,
validation_metric=validation_metric,
)
val_losses.append(min_val_loss)
if min_val_loss < lowest_loss:
lowest_loss = min_val_loss
best_detector = detector
return best_detector, lowest_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="google/gemma-2b-it",
help=("LM model to train the detector for."),
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help=("Temperature to sample from the model."),
)
parser.add_argument(
"--top_k",
type=int,
default=40,
help=("Top K for sampling."),
)
parser.add_argument(
"--top_p",
type=float,
default=1.0,
help=("Top P for sampling."),
)
parser.add_argument(
"--num_negatives",
type=int,
default=10000,
help=("Number of negatives for detector training."),
)
parser.add_argument(
"--pos_batch_size",
type=int,
default=32,
help=("Batch size of watermarked positives while sampling."),
)
parser.add_argument(
"--num_pos_batch",
type=int,
default=313,
help=("Number of positive batches for training."),
)
parser.add_argument(
"--generation_length",
type=int,
default=512,
help=("Generation length for sampling."),
)
parser.add_argument(
"--save_model_to_hf_hub",
action="store_true",
help=("Whether to save the trained model HF hub. By default it will be a private repo."),
)
parser.add_argument(
"--load_from_hf_hub",
action="store_true",
help=(
"Whether to load trained detector model from HF Hub, make sure its the model trained on the same model "
"we are loading in the script."
),
)
parser.add_argument(
"--hf_hub_model_name",
type=str,
default=None,
help=("HF hub model name for loading of saving the model."),
)
parser.add_argument(
"--eval_detector_on_prompts",
action="store_true",
help=("Evaluate detector on a prompt and print probability of watermark."),
)
args = parser.parse_args()
model_name = args.model_name
temperature = args.temperature
top_k = args.top_k
top_p = args.top_p
num_negatives = args.num_negatives
pos_batch_size = args.pos_batch_size
num_pos_batch = args.num_pos_batch
if num_pos_batch < 10:
raise ValueError("--num_pos_batch should be greater than 10.")
generation_length = args.generation_length
save_model_to_hf_hub = args.save_model_to_hf_hub
load_from_hf_hub = args.load_from_hf_hub
repo_name = args.hf_hub_model_name
eval_detector_on_prompts = args.eval_detector_on_prompts
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 200
NEG_TRUNCATION_LENGTH = 100
# Pad trucated outputs to this length for equal shape across all batches.
MAX_PADDED_LENGTH = 1000
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type not in ("cuda", "tpu"):
raise ValueError("We have found the training stable on GPU and TPU, we are working on" " a fix for CPUs")
model = None
if not load_from_hf_hub:
# Change this to make your watermark unique. Check documentation in the paper to understand the
# impact of these parameters.
DEFAULT_WATERMARKING_CONFIG = {
"ngram_len": 5, # This corresponds to H=4 context window size in the paper.
"keys": [
654,
400,
836,
123,
340,
443,
597,
160,
57,
29,
590,
639,
13,
715,
468,
990,
966,
226,
324,
585,
118,
504,
421,
521,
129,
669,
732,
225,
90,
960,
],
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
}
watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
tokenized_wm_outputs = get_tokenized_wm_outputs(
model,
tokenizer,
watermark_config,
num_pos_batch,
pos_batch_size,
temperature,
generation_length,
top_k,
top_p,
DEVICE,
)
tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)
best_detector, lowest_loss = train_best_detector(
tokenized_wm_outputs=tokenized_wm_outputs,
tokenized_uwm_outputs=tokenized_uwm_outputs,
logits_processor=logits_processor,
tokenizer=tokenizer,
torch_device=DEVICE,
test_size=0.3,
pos_truncation_length=POS_TRUNCATION_LENGTH,
neg_truncation_length=NEG_TRUNCATION_LENGTH,
max_padded_length=MAX_PADDED_LENGTH,
n_epochs=100,
learning_rate=3e-3,
l2_weights=[
0,
],
verbose=True,
validation_metric=ValidationMetric.TPR_AT_FPR,
)
else:
if repo_name is None:
raise ValueError("When loading from pretrained detector model name cannot be None.")
best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)
best_detector.config.set_detector_information(
model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
)
if save_model_to_hf_hub:
upload_model_to_hf(best_detector, repo_name)
# Evaluate model response with the detector
if eval_detector_on_prompts:
model_name = best_detector.config.model_name
watermark_config_dict = best_detector.config.watermarking_config
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)
if model is None:
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)
prompts = ["Write a essay on cats."]
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).to(DEVICE)
_, inputs_len = inputs["input_ids"].shape
outputs = model.generate(
**inputs,
watermarking_config=watermarking_config,
do_sample=True,
max_length=inputs_len + generation_length,
temperature=temperature,
top_k=40,
top_p=1.0,
)
outputs = outputs[:, inputs_len:]
result = synthid_text_detector(outputs)
# You should set this based on expected fpr (false positive rate) and tpr (true positive rate).
# Check our demo at HF Spaces for more info.
upper_threshold = 0.95
lower_threshold = 0.12
if result[0][0] > upper_threshold:
print("The text is watermarked.")
elif lower_threshold < result[0][0] < upper_threshold:
print("It is hard to determine if the text is watermarked or not.")
else:
print("The text is not watermarked.")

View File

@@ -0,0 +1,5 @@
tensorflow-datasets>=4.9.3
torch >= 1.3
datasets
scikit-learn
tensorflow

View File

@@ -0,0 +1,408 @@
# coding=utf-8
# Copyright 2024 Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from typing import Any, List, Optional, Tuple
import datasets
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import RepositoryNotFoundError
from sklearn import model_selection
import transformers
def pad_to_len(
arr: torch.Tensor,
target_len: int,
left_pad: bool,
eos_token: int,
device: torch.device,
) -> torch.Tensor:
"""Pad or truncate array to given length."""
if arr.shape[1] < target_len:
shape_for_ones = list(arr.shape)
shape_for_ones[1] = target_len - shape_for_ones[1]
padded = (
torch.ones(
shape_for_ones,
device=device,
dtype=torch.long,
)
* eos_token
)
if not left_pad:
arr = torch.concatenate((arr, padded), dim=1)
else:
arr = torch.concatenate((padded, arr), dim=1)
else:
arr = arr[:, :target_len]
return arr
def filter_and_truncate(
outputs: torch.Tensor,
truncation_length: Optional[int],
eos_token_mask: torch.Tensor,
) -> torch.Tensor:
"""Filter and truncate outputs to given length.
Args:
outputs: output tensor of shape [batch_size, output_len]
truncation_length: Length to truncate the final output.
eos_token_mask: EOS token mask of shape [batch_size, output_len]
Returns:
output tensor of shape [batch_size, truncation_length].
"""
if truncation_length:
outputs = outputs[:, :truncation_length]
truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length
return outputs[truncation_mask, :]
return outputs
def process_outputs_for_training(
all_outputs: List[torch.Tensor],
logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor,
tokenizer: Any,
pos_truncation_length: Optional[int],
neg_truncation_length: Optional[int],
max_length: int,
is_cv: bool,
is_pos: bool,
torch_device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Process raw model outputs into format understandable by the detector.
Args:
all_outputs: sequence of outputs of shape [batch_size, output_len].
logits_processor: logits processor used for watermarking.
tokenizer: tokenizer used for the model.
pos_truncation_length: Length to truncate wm outputs.
neg_truncation_length: Length to truncate uwm outputs.
max_length: Length to pad truncated outputs so that all processed entries.
have same shape.
is_cv: Process given outputs for cross validation.
is_pos: Process given outputs for positives.
torch_device: torch device to use.
Returns:
Tuple of
all_masks: list of masks of shape [batch_size, max_length].
all_g_values: list of g_values of shape [batch_size, max_length, depth].
"""
all_masks = []
all_g_values = []
for outputs in tqdm.tqdm(all_outputs):
# outputs is of shape [batch_size, output_len].
# output_len can differ from batch to batch.
eos_token_mask = logits_processor.compute_eos_token_mask(
input_ids=outputs,
eos_token_id=tokenizer.eos_token_id,
)
if is_pos or is_cv:
# filter with length for positives for both train and CV.
# We also filter for length when CV negatives are processed.
outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask)
elif not is_pos and not is_cv:
outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask)
# If no filtered outputs skip this batch.
if outputs.shape[0] == 0:
continue
# All outputs are padded to max-length with eos-tokens.
outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device)
# outputs shape [num_filtered_entries, max_length]
eos_token_mask = logits_processor.compute_eos_token_mask(
input_ids=outputs,
eos_token_id=tokenizer.eos_token_id,
)
context_repetition_mask = logits_processor.compute_context_repetition_mask(
input_ids=outputs,
)
# context_repetition_mask of shape [num_filtered_entries, max_length -
# (ngram_len - 1)].
context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device)
# We pad on left to get same max_length shape.
# context_repetition_mask of shape [num_filtered_entries, max_length].
combined_mask = context_repetition_mask * eos_token_mask
g_values = logits_processor.compute_g_values(
input_ids=outputs,
)
# g_values of shape [num_filtered_entries, max_length - (ngram_len - 1),
# depth].
g_values = pad_to_len(g_values, max_length, True, 0, torch_device)
# We pad on left to get same max_length shape.
# g_values of shape [num_filtered_entries, max_length, depth].
all_masks.append(combined_mask)
all_g_values.append(g_values)
return all_masks, all_g_values
def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> torch.Tensor:
"""Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
positive_idxs = w_true == 1
negative_idxs = w_true == 0
num_samples = detector_inputs[0].size(0)
w_preds = []
for start in range(0, num_samples, minibatch_size):
end = start + minibatch_size
detector_inputs_ = (
detector_inputs[0][start:end],
detector_inputs[1][start:end],
)
with torch.no_grad():
w_pred = detector(*detector_inputs_)[0]
w_preds.append(w_pred)
w_pred = torch.cat(w_preds, dim=0) # Concatenate predictions
positive_scores = w_pred[positive_idxs]
negative_scores = w_pred[negative_idxs]
# Calculate the FPR threshold
# Note: percentile -> quantile
fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)
# Note: need to switch to FP32 since torch.mean doesn't work with torch.bool
return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item() # TPR
def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
"""Loss function for negative TPR@FPR=1% as the validation loss."""
tpr_ = tpr_at_fpr(
detector=detector,
detector_inputs=(g_values_val, mask_val),
w_true=watermarked_val,
minibatch_size=minibatch_size,
)
return -tpr_
def process_raw_model_outputs(
logits_processor,
tokenizer,
pos_truncation_length,
neg_truncation_length,
max_padded_length,
tokenized_wm_outputs,
test_size,
tokenized_uwm_outputs,
torch_device,
):
# Split data into train and CV
train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)
process_kwargs = {
"logits_processor": logits_processor,
"tokenizer": tokenizer,
"pos_truncation_length": pos_truncation_length,
"neg_truncation_length": neg_truncation_length,
"max_length": max_padded_length,
"torch_device": torch_device,
}
# Process both train and CV data for training
wm_masks_train, wm_g_values_train = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
is_pos=True,
is_cv=False,
**process_kwargs,
)
wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
is_pos=True,
is_cv=True,
**process_kwargs,
)
uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
is_pos=False,
is_cv=False,
**process_kwargs,
)
uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
is_pos=False,
is_cv=True,
**process_kwargs,
)
# We get list of data; here we concat all together to be passed to the detector.
def pack(mask, g_values):
mask = torch.cat(mask, dim=0)
g = torch.cat(g_values, dim=0)
return mask, g
wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
# Note: Use float instead of bool. Otherwise, the entropy calculation doesn't work
wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
# Concat pos and negatives data together.
train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
train_labels = torch.cat((wm_labels_train, uwm_labels_train), axis=0).squeeze()
train_masks = torch.cat((wm_masks_train, uwm_masks_train), axis=0).squeeze()
cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), axis=0).squeeze()
cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), axis=0).squeeze()
cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), axis=0).squeeze()
# Shuffle data.
shuffled_idx = torch.randperm(train_g_values.shape[0]) # Use torch for GPU compatibility
train_g_values = train_g_values[shuffled_idx]
train_labels = train_labels[shuffled_idx]
train_masks = train_masks[shuffled_idx]
# Shuffle the cross-validation data
shuffled_idx_cv = torch.randperm(cv_g_values.shape[0]) # Use torch for GPU compatibility
cv_g_values = cv_g_values[shuffled_idx_cv]
cv_labels = cv_labels[shuffled_idx_cv]
cv_masks = cv_masks[shuffled_idx_cv]
# Del some variables so we free up GPU memory.
del (
wm_g_values_train,
wm_labels_train,
wm_masks_train,
wm_g_values_cv,
wm_labels_cv,
wm_masks_cv,
)
gc.collect()
torch.cuda.empty_cache()
return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
dataset = dataset.take(num_negatives)
# Convert the dataset to a DataFrame
df = tfds.as_dataframe(dataset, info)
ds = tf.data.Dataset.from_tensor_slices(dict(df))
tf.random.set_seed(0)
ds = ds.shuffle(buffer_size=10_000)
ds = ds.batch(batch_size=neg_batch_size)
tokenized_uwm_outputs = []
# Pad to this length (on the right) for batching.
padded_length = 1000
for i, batch in tqdm.tqdm(enumerate(ds)):
responses = [val.decode() for val in batch["text"].numpy()]
inputs = tokenizer(
responses,
return_tensors="pt",
padding=True,
).to(device)
inputs = inputs["input_ids"].cpu().numpy()
if inputs.shape[1] >= padded_length:
inputs = inputs[:, :padded_length]
else:
inputs = np.concatenate(
[inputs, np.ones((neg_batch_size, padded_length - inputs.shape[1])) * tokenizer.eos_token_id], axis=1
)
tokenized_uwm_outputs.append(inputs)
if len(tokenized_uwm_outputs) * neg_batch_size > num_negatives:
break
return tokenized_uwm_outputs
def get_tokenized_wm_outputs(
model,
tokenizer,
watermark_config,
num_pos_batches,
pos_batch_size,
temperature,
max_output_len,
top_k,
top_p,
device,
):
eli5_prompts = datasets.load_dataset("Pavithree/eli5")
wm_outputs = []
for batch_id in tqdm.tqdm(range(num_pos_batches)):
prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
prompts = [prompt.strip('"') for prompt in prompts]
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).to(device)
_, inputs_len = inputs["input_ids"].shape
outputs = model.generate(
**inputs,
watermarking_config=watermark_config,
do_sample=True,
max_length=inputs_len + max_output_len,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
wm_outputs.append(outputs[:, inputs_len:].cpu().detach())
del outputs, inputs, prompts
gc.collect()
gc.collect()
torch.cuda.empty_cache()
return wm_outputs
def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
api = HfApi()
# Check if the repository exists
try:
api.repo_info(repo_id=hf_repo_name, use_auth_token=True)
print(f"Repository '{hf_repo_name}' already exists.")
except RepositoryNotFoundError:
# If the repository does not exist, create it
print(f"Repository '{hf_repo_name}' not found. Creating it...")
create_repo(repo_id=hf_repo_name, private=private, use_auth_token=True)
print(f"Repository '{hf_repo_name}' created successfully.")
# Push the model to the Hugging Face Hub
print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
model.push_to_hub(repo_id=hf_repo_name, use_auth_token=True)