From 5ad960f1f4f77f436ddf3de3692d09949a27c2df Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 14 May 2024 13:31:39 +0500 Subject: [PATCH] Add Watermarking LogitsProcessor and WatermarkDetector (#29676) * add watermarking processor * remove the other hashing (context width=1 always) * make style * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * update watermarking process * add detector * update tests to use detector * fix failing tests * rename `input_seq` * make style * doc for processor * minor fixes * docs * make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante * add PR suggestions * let's use lru_cache's default max size (128) * import processor if torch available * maybe like this * lets move the config to torch independet file * add docs * tiny docs fix to make the test happy * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante * PR suggestions * add docs * fix test * fix docs * address pr comments * style * Revert "style" This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2. * correct style * make doctest green --------- Co-authored-by: Joao Gante --- docs/source/en/generation_strategies.md | 49 ++++ docs/source/en/internal/generation_utils.md | 11 + .../source/en/main_classes/text_generation.md | 2 + src/transformers/__init__.py | 13 +- src/transformers/generation/__init__.py | 14 +- .../generation/configuration_utils.py | 179 +++++++++++++ src/transformers/generation/logits_process.py | 141 ++++++++++ src/transformers/generation/utils.py | 15 ++ src/transformers/generation/watermarking.py | 240 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 14 + tests/generation/test_logits_process.py | 24 ++ tests/generation/test_utils.py | 40 +++ 12 files changed, 738 insertions(+), 4 deletions(-) create mode 100644 src/transformers/generation/watermarking.py diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 3d4829c3e3..6c7c70cb14 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -173,6 +173,55 @@ your screen, one word at a time: An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven, ``` + +## Watermarking + +The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green". +When generating the "green" will have a small 'bias' value added to their logits, thus having a higher chance to be generated. +The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is +statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper +["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on +the inner functioning of watermarking, it is recommended to refer to the paper. + +The watermarking can be used with any generative model in `tranformers` and does not require an extra classification model +to detect watermarked text. To trigger watermarking, pass in a [`WatermarkingConfig`] with needed arguments directly to the +`.generate()` method or add it to the [`GenerationConfig`]. Watermarked text can be later detected with a [`WatermarkDetector`]. + + + + +The WatermarkDetector internally relies on the proportion of "green" tokens, and whether generated text follows the coloring pattern. +That is why it is recommended to strip off the prompt text, if it is much longer than the generated text. +This also can have an effect when one sequence in the batch is a lot longer causing other rows to be padded. +Additionally, the detector **must** be initiated with identical watermark configuration arguments used when generating. + + + +Let's generate some text with watermarking. In the below code snippet, we set the bias to 2.5 which is a value that +will be added to "green" tokens' logits. After generating watermarked text, we can pass it directly to the `WatermarkDetector` +to check if the text is machine-generated (outputs `True` for machine-generated and `False` otherwise). + +```python +>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig + +>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") +>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2") +>>> tok.pad_token_id = tok.eos_token_id +>>> tok.padding_side = "left" + +>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt") +>>> input_len = inputs["input_ids"].shape[-1] + +>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") +>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20) + +>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config) +>>> detection_out = detector(out, return_dict=True) +>>> detection_out.prediction +array([True, True]) +``` + + ## Decoding strategies Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 19b80914c9..04a4428a00 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -209,6 +209,10 @@ generation. [[autodoc]] WhisperTimeStampLogitsProcessor - __call__ +[[autodoc]] WatermarkLogitsProcessor + - __call__ + + ### TensorFlow [[autodoc]] TFForcedBOSTokenLogitsProcessor @@ -372,3 +376,10 @@ A [`Constraint`] can be used to force the generation to include specific tokens - update - get_seq_length - reorder_cache + + +## Watermark Utils + +[[autodoc]] WatermarkDetector + - __call__ + diff --git a/docs/source/en/main_classes/text_generation.md b/docs/source/en/main_classes/text_generation.md index dec524d257..e2c5ce9c0b 100644 --- a/docs/source/en/main_classes/text_generation.md +++ b/docs/source/en/main_classes/text_generation.md @@ -41,6 +41,8 @@ like token streaming. - validate - get_generation_mode +[[autodoc]] generation.WatermarkingConfig + ## GenerationMixin [[autodoc]] generation.GenerationMixin diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 97a4e89684..502c561cc8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -117,7 +117,12 @@ _import_structure = { "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"], "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], "file_utils": [], - "generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"], + "generation": [ + "GenerationConfig", + "TextIteratorStreamer", + "TextStreamer", + "WatermarkingConfig", + ], "hf_argparser": ["HfArgumentParser"], "hyperparameter_search": [], "image_transforms": [], @@ -1232,6 +1237,8 @@ else: "TopPLogitsWarper", "TypicalLogitsWarper", "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WatermarkDetector", + "WatermarkLogitsProcessor", "WhisperTimeStampLogitsProcessor", ] ) @@ -4617,7 +4624,7 @@ if TYPE_CHECKING: from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin # Generation - from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer + from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig from .hf_argparser import HfArgumentParser # Integrations @@ -5797,6 +5804,8 @@ if TYPE_CHECKING: TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WatermarkDetector, + WatermarkLogitsProcessor, WhisperTimeStampLogitsProcessor, ) from .modeling_utils import PreTrainedModel diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index d5912984c1..6880321d63 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab _import_structure = { - "configuration_utils": ["GenerationConfig", "GenerationMode"], + "configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"], "streamers": ["TextIteratorStreamer", "TextStreamer"], } @@ -78,6 +78,7 @@ else: "TypicalLogitsWarper", "UnbatchedClassifierFreeGuidanceLogitsProcessor", "WhisperTimeStampLogitsProcessor", + "WatermarkLogitsProcessor", ] _import_structure["stopping_criteria"] = [ "MaxNewTokensCriteria", @@ -106,6 +107,10 @@ else: "GenerateDecoderOnlyOutput", "GenerateEncoderDecoderOutput", ] + _import_structure["watermarking"] = [ + "WatermarkDetector", + "WatermarkDetectorOutput", + ] try: if not is_tf_available(): @@ -174,7 +179,7 @@ else: ] if TYPE_CHECKING: - from .configuration_utils import GenerationConfig, GenerationMode + from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig from .streamers import TextIteratorStreamer, TextStreamer try: @@ -218,6 +223,7 @@ if TYPE_CHECKING: TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WatermarkLogitsProcessor, WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( @@ -247,6 +253,10 @@ if TYPE_CHECKING: SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, ) + from .watermarking import ( + WatermarkDetector, + WatermarkDetectorOutput, + ) try: if not is_tf_available(): diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 2bdf20c686..9a5fee5207 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -18,6 +18,7 @@ import copy import json import os import warnings +from dataclasses import dataclass, is_dataclass from typing import TYPE_CHECKING, Any, Dict, Optional, Union from .. import __version__ @@ -221,6 +222,23 @@ class GenerationConfig(PushToHubMixin): low_memory (`bool`, *optional*): Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory. Used with beam search and contrastive search. + watermarking_config (Union[`WatermarkingConfig`, `dict`], *optional*): + Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens. + If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally. + See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys: + - greenlist_ratio (`float`): + Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25. + - bias (`float`): + Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0. + - hashing_key (`int`): + Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime). + - seeding_scheme (`str`): + Algorithm to use for watermarking. Accepts values: + - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper) + - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper) + The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash". + - context_width(`int`): + The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust. > Parameters that define the output variables of generate @@ -333,6 +351,13 @@ class GenerationConfig(PushToHubMixin): self.sequence_bias = kwargs.pop("sequence_bias", None) self.guidance_scale = kwargs.pop("guidance_scale", None) self.low_memory = kwargs.pop("low_memory", None) + watermarking_config = kwargs.pop("watermarking_config", None) + if watermarking_config is None: + self.watermarking_config = None + elif isinstance(watermarking_config, WatermarkingConfig): + self.watermarking_config = watermarking_config + else: + self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) @@ -613,6 +638,12 @@ class GenerationConfig(PushToHubMixin): f"({self.num_beams})." ) + # check watermarking arguments + if self.watermarking_config is not None: + if not isinstance(self.watermarking_config, WatermarkingConfig): + self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config) + self.watermarking_config.validate() + # 5. check common issue: passing `generate` arguments inside the generation config generate_arguments = ( "logits_processor", @@ -1021,7 +1052,16 @@ class GenerationConfig(PushToHubMixin): else: return obj + def convert_dataclass_to_dict(obj): + if isinstance(obj, dict): + return {key: convert_dataclass_to_dict(value) for key, value in obj.items()} + elif is_dataclass(obj): + return obj.to_dict() + else: + return obj + config_dict = convert_keys_to_string(config_dict) + config_dict = convert_dataclass_to_dict(config_dict) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" @@ -1093,3 +1133,142 @@ class GenerationConfig(PushToHubMixin): # Remove all the attributes that were updated, without modifying the input dict unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} return unused_kwargs + + +@dataclass +class WatermarkingConfig: + """ + Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`. + See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments. + + Accepts the following keys: + - greenlist_ratio (`float`): + Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25. + - bias (`float`): + Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0. + - hashing_key (`int`): + Hashing key used for watermarking. Defaults to 15485863 (the millionth prime). + - seeding_scheme (`str`): + Algorithm to use for watermarking. Accepts values: + - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper) + - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper) + The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash". + - context_width(`int`): + The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust. + """ + + def __init__( + self, + greenlist_ratio: Optional[float] = 0.25, + bias: Optional[float] = 2.0, + hashing_key: Optional[int] = 15485863, + seeding_scheme: Optional[str] = "lefthash", + context_width: Optional[int] = 1, + ): + self.greenlist_ratio = greenlist_ratio + self.bias = bias + self.hashing_key = hashing_key + self.seeding_scheme = seeding_scheme + self.context_width = context_width + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a WatermarkingConfig instance from a dictionary of parameters. + + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + return output + + def __iter__(self): + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + def update(self, **kwargs): + """ + Update the configuration attributes with new values. + + Args: + **kwargs: Keyword arguments representing configuration attributes and their new values. + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + + def validate(self): + watermark_missing_arg_msg = ( + "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + if self.seeding_scheme not in ["selfhash", "lefthash"]: + raise ValueError( + watermark_missing_arg_msg.format( + key="seeding_scheme", + correct_value="[`selfhash`, `lefthash`]", + found_value=self.seeding_scheme, + ), + ) + if not 0.0 <= self.greenlist_ratio <= 1.0: + raise ValueError( + watermark_missing_arg_msg.format( + key="greenlist_ratio", + correct_value="in range between 0.0 and 1.0", + found_value=self.seeding_scheme, + ), + ) + if not self.context_width >= 1: + raise ValueError( + watermark_missing_arg_msg.format( + key="context_width", + correct_value="a positive integer", + found_value=self.context_width, + ), + ) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 958d436a6a..d870446504 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2321,3 +2321,144 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): scores_processed = torch.where(do_early_stop, early_stop_scores, scores) return scores_processed + + +class WatermarkLogitsProcessor(LogitsProcessor): + r""" + Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to + randomized set of "green" tokens before generating the next token. "Green" tokens selection process depends on the + `seeding_scheme` used. The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main). + + The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details, + + See [the paper](https://arxiv.org/abs/2306.04634) for more information. + + Args: + vocab_size (`int`): + The model tokenizer's vocab_size. Used to calculate "green" tokens ratio. + device (`str`): + The device where model is allocated. + greenlist_ratio (`float`, optional, *optional*, defaults to 0.25): + The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25. + bias (`float`, optional, *optional*, defaults to 2.0): + The bias added to the selected "green" tokens' logits. Consider lowering the + `bias` if the text generation quality degrades. Recommended values are in the + range of [0.5, 2.0]. Defaults to 2.0. + hashing_key (`int`, optional, *optional*, defaults to 15485863): + Key used for hashing. If you deploy this watermark, we advise using another private key. + Defaults to 15485863 (the millionth prime). + seeding_scheme (`str`, optional, *optional*, defaults to `"lefthash"`): + The seeding scheme used for selecting "green" tokens. Accepts values: + - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from paper) + - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from paper) + The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash". + The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust. + context_width (`int`, *optional*, defaults to 1): + The number of previous tokens to use when setting the seed. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> inputs = tokenizer(["Alice and Bob are"], return_tensors="pt") + + >>> # normal generation + >>> out = model.generate(inputs["input_ids"], max_length=20, do_sample=False) + >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] + 'Alice and Bob are both in the same room.\n\n"I\'m not sure if you\'re' + + >>> # watermarked generation + >>> watermarking_config = WatermarkingConfig(bias=2.5, context_width=2, seeding_scheme="selfhash") + >>> out = model.generate(inputs["input_ids"], watermarking_config=watermarking_config, max_length=20, do_sample=False) + >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] + 'Alice and Bob are both still alive and well and the story is pretty much a one-hour adventure' + + >>> # to detect watermarked text use the WatermarkDetector class + >>> from transformers import WatermarkDetector + >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config= watermarking_config) + >>> detection_preds = detector(out) + >>> detection_preds + array([ True]) + ``` + """ + + def __init__( + self, + vocab_size, + device, + greenlist_ratio: float = 0.25, + bias: float = 2.0, + hashing_key: int = 15485863, + seeding_scheme: str = "lefthash", + context_width: int = 1, + ): + if seeding_scheme not in ["selfhash", "lefthash"]: + raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but found {seeding_scheme}") + if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0: + raise ValueError( + f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}" + ) + + self.vocab_size = vocab_size + self.greenlist_size = int(self.vocab_size * greenlist_ratio) + self.bias = bias + self.seeding_scheme = seeding_scheme + self.rng = torch.Generator(device=device) + self.hash_key = hashing_key + self.context_width = context_width + + self.rng.manual_seed(hashing_key) + self.table_size = 1_000_003 + self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device) + + def set_seed(self, input_seq: torch.LongTensor): + input_seq = input_seq[-self.context_width :] + if self.seeding_scheme == "selfhash": + a = self.fixed_table[input_seq % self.table_size] + 1 + b = self.fixed_table[input_seq[-1] % self.table_size] + 1 + seed = (self.hash_key * a * b).min().item() + else: + seed = self.hash_key * input_seq[-1].item() + self.rng.manual_seed(seed % (2**64 - 1)) + + def _get_greenlist_ids(self, input_seq: torch.LongTensor) -> torch.LongTensor: + self.set_seed(input_seq) + vocab_permutation = torch.randperm(self.vocab_size, device=input_seq.device, generator=self.rng) + greenlist_ids = vocab_permutation[: self.greenlist_size] + return greenlist_ids + + def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor: + """ + Generate greenlist based on current candidate next token. Reject and move on if necessary. + Runs for a fixed number of steps only for efficiency, since the methods is not batched. + """ + final_greenlist = [] + _, greedy_predictions = scores.sort(dim=-1, descending=True) + + # 40 is an arbitrary number chosen to save compute and not run for long (taken from orig repo) + for i in range(40): + greenlist_ids = self._get_greenlist_ids(torch.cat([input_seq, greedy_predictions[i, None]], dim=-1)) + if greedy_predictions[i] in greenlist_ids: + final_greenlist.append(greedy_predictions[i]) + return torch.tensor(final_greenlist, device=input_seq.device) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.shape[-1] < self.context_width: + logger.warning( + f"`input_ids` should have at least `{self.context_width}` tokens but has {input_ids.shape[-1]}. " + "The seeding will be skipped for this generation step!" + ) + return scores + + scores_processed = scores.clone() + for b_idx, input_seq in enumerate(input_ids): + if self.seeding_scheme == "selfhash": + greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx]) + else: + greenlist_ids = self._get_greenlist_ids(input_seq) + scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias + + return scores_processed diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1c90fdd307..086cc99846 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -74,6 +74,7 @@ from .logits_process import ( TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WatermarkLogitsProcessor, ) from .stopping_criteria import ( EosTokenCriteria, @@ -763,6 +764,7 @@ class GenerationMixin: encoder_input_ids: torch.LongTensor, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], logits_processor: Optional[LogitsProcessorList], + device: str = None, model_kwargs: Optional[Dict[str, Any]] = None, negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, @@ -879,6 +881,18 @@ class GenerationMixin: FutureWarning, ) processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True)) + if generation_config.watermarking_config is not None: + processors.append( + WatermarkLogitsProcessor( + vocab_size=self.config.vocab_size, + device=device, + greenlist_ratio=generation_config.watermarking_config.greenlist_ratio, + bias=generation_config.watermarking_config.bias, + hashing_key=generation_config.watermarking_config.hashing_key, + seeding_scheme=generation_config.watermarking_config.seeding_scheme, + context_width=generation_config.watermarking_config.context_width, + ) + ) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -1632,6 +1646,7 @@ class GenerationMixin: encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, + device=inputs_tensor.device, model_kwargs=model_kwargs, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, diff --git a/src/transformers/generation/watermarking.py b/src/transformers/generation/watermarking.py new file mode 100644 index 0000000000..297d388d54 --- /dev/null +++ b/src/transformers/generation/watermarking.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from dataclasses import dataclass +from functools import lru_cache +from typing import Dict, Optional, Union + +import numpy as np + +from ..configuration_utils import PretrainedConfig +from ..utils import is_torch_available, logging +from .configuration_utils import WatermarkingConfig + + +if is_torch_available(): + import torch + + from .logits_process import WatermarkLogitsProcessor + + +logger = logging.get_logger(__name__) + + +@dataclass +class WatermarkDetectorOutput: + """ + Outputs of a watermark detector. + + Args: + num_tokens_scored (np.array of shape (batch_size)): + Array containing the number of tokens scored for each element in the batch. + num_green_tokens (np.array of shape (batch_size)): + Array containing the number of green tokens for each element in the batch. + green_fraction (np.array of shape (batch_size)): + Array containing the fraction of green tokens for each element in the batch. + z_score (np.array of shape (batch_size)): + Array containing the z-score for each element in the batch. Z-score here shows + how many standard deviations away is the green token count in the input text + from the expected green token count for machine-generated text. + p_value (np.array of shape (batch_size)): + Array containing the p-value for each batch obtained from z-scores. + prediction (np.array of shape (batch_size)), *optional*: + Array containing boolean predictions whether a text is machine-generated for each element in the batch. + confidence (np.array of shape (batch_size)), *optional*: + Array containing confidence scores of a text being machine-generated for each element in the batch. + """ + + num_tokens_scored: np.array = None + num_green_tokens: np.array = None + green_fraction: np.array = None + z_score: np.array = None + p_value: np.array = None + prediction: Optional[np.array] = None + confidence: Optional[np.array] = None + + +class WatermarkDetector: + + r""" + Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were + given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes + the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size. + The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main). + + See [the paper](https://arxiv.org/abs/2306.04634) for more information. + + Args: + model_config (`PretrainedConfig`): + The model config that will be used to get model specific arguments used when generating. + device (`str`): + The device which was used during watermarked text generation. + watermarking_config (Union[`WatermarkingConfig`, `Dict`]): + The exact same watermarking config and arguments used when generating text. + ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`): + Whether to count every unique ngram only once or not. + max_cache_size (`int`, *optional*, defaults to 128): + The max size to be used for LRU caching of seeding/sampling algorithms called for every token. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig + + >>> model_id = "openai-community/gpt2" + >>> model = AutoModelForCausalLM.from_pretrained(model_id) + >>> tok = AutoTokenizer.from_pretrained(model_id) + >>> tok.pad_token_id = tok.eos_token_id + >>> tok.padding_side = "left" + + >>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt") + >>> input_len = inputs["input_ids"].shape[-1] + + >>> # first generate text with watermark and without + >>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") + >>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20) + >>> out = model.generate(**inputs, do_sample=False, max_length=20) + + >>> # now we can instantiate the detector and check the generated text + >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config) + >>> detection_out_watermarked = detector(out_watermarked, return_dict=True) + >>> detection_out = detector(out, return_dict=True) + >>> detection_out_watermarked.prediction + array([ True, True]) + + >>> detection_out.prediction + array([False, False]) + ``` + """ + + def __init__( + self, + model_config: PretrainedConfig, + device: str, + watermarking_config: Union[WatermarkingConfig, Dict], + ignore_repeated_ngrams: bool = False, + max_cache_size: int = 128, + ): + if isinstance(watermarking_config, WatermarkingConfig): + watermarking_config = watermarking_config.to_dict() + + self.bos_token_id = ( + model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id + ) + self.greenlist_ratio = watermarking_config["greenlist_ratio"] + self.ignore_repeated_ngrams = ignore_repeated_ngrams + self.processor = WatermarkLogitsProcessor( + vocab_size=model_config.vocab_size, device=device, **watermarking_config + ) + + # Expensive re-seeding and sampling is cached. + self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score) + + def _get_ngram_score(self, prefix: torch.LongTensor, target: int): + greenlist_ids = self.processor._get_greenlist_ids(prefix) + return target in greenlist_ids + + def _score_ngrams_in_passage(self, input_ids: torch.LongTensor): + batch_size, seq_length = input_ids.shape + selfhash = int(self.processor.seeding_scheme == "selfhash") + n = self.processor.context_width + 1 - selfhash + indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1) + ngram_tensors = input_ids[:, indices] + + num_tokens_scored_batch = np.zeros(batch_size) + green_token_count_batch = np.zeros(batch_size) + for batch_idx in range(ngram_tensors.shape[0]): + frequencies_table = collections.Counter(ngram_tensors[batch_idx]) + ngram_to_watermark_lookup = {} + for ngram_example in frequencies_table.keys(): + prefix = ngram_example if selfhash else ngram_example[:-1] + target = ngram_example[-1] + ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target) + + if self.ignore_repeated_ngrams: + # counts a green/red hit once per unique ngram. + # num total tokens scored becomes the number unique ngrams. + num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys()) + green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values()) + else: + num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values()) + green_token_count_batch[batch_idx] = sum( + freq * outcome + for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values()) + ) + return num_tokens_scored_batch, green_token_count_batch + + def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array: + expected_count = self.greenlist_ratio + numer = green_token_count - expected_count * total_num_tokens + denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count)) + z = numer / denom + return z + + def _compute_pval(self, x, loc=0, scale=1): + z = (x - loc) / scale + return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi)))) + + def __call__( + self, + input_ids: torch.LongTensor, + z_threshold: float = 3.0, + return_dict: bool = False, + ) -> Union[WatermarkDetectorOutput, np.array]: + """ + Args: + input_ids (`torch.LongTensor`): + The watermark generated text. It is advised to remove the prompt, which can affect the detection. + z_threshold (`Dict`, *optional*, defaults to `3.0`): + Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less + sensitivity and vice versa for lower z threshold. + return_dict (`bool`, *optional*, defaults to `False`): + Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions, + ma + Return: + [`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`] + if `return_dict=True` otherwise a `np.array`. + + """ + + # Let's assume that if one batch start with `bos`, all batched also do + if input_ids[0, 0] == self.bos_token_id: + input_ids = input_ids[:, 1:] + + if input_ids.shape[-1] - self.processor.context_width < 1: + raise ValueError( + f"Must have at least `1` token to score after the first " + f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme." + ) + + num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids) + z_score = self._compute_z_score(green_token_count, num_tokens_scored) + prediction = z_score > z_threshold + + if return_dict: + p_value = self._compute_pval(z_score) + confidence = 1 - p_value + + return WatermarkDetectorOutput( + num_tokens_scored=num_tokens_scored, + num_green_tokens=green_token_count, + green_fraction=green_token_count / num_tokens_scored, + z_score=z_score, + p_value=p_value, + prediction=prediction, + confidence=confidence, + ) + return prediction diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d6eae8aafd..7d9813d150 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -422,6 +422,20 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["torch"]) +class WatermarkDetector(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WatermarkLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class WhisperTimeStampLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 30083e4f1f..775e702a02 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -53,6 +53,7 @@ if is_torch_available(): TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WatermarkLogitsProcessor, ) from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor @@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase): [float("-inf"), float("-inf"), scores[0][0], scores[0][0]], ] self.assertListEqual(actual_scores.tolist(), expected_scores_list) + + def test_watermarking_processor(self): + batch_size = 3 + vocab_size = 20 + + input_ids = ids_tensor((batch_size, 5), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + + # raise error if incorrect seeding_scheme is passed + with self.assertRaises(ValueError): + WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash") + + # raise error if the greenlist_ratio in not in range (0.0, 1.0) + with self.assertRaises(ValueError): + WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2) + + watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device) + + # use fixed id for last token, needed for reprodicibility and tests + input_ids[:, -1] = 10 + scores_wo_bias = scores[:, -1].clone() + out = watermark(input_ids=input_ids, scores=scores) + self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all()) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3539398334..cf703d8a22 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -75,6 +75,8 @@ if is_torch_available(): SampleEncoderDecoderOutput, StoppingCriteria, StoppingCriteriaList, + WatermarkDetector, + WatermarkingConfig, ) from transformers.generation.utils import _speculative_sampling @@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) self.assertListEqual(low_output.tolist(), high_output.tolist()) + @slow + def test_watermark_generation(self): + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer.pad_token_id = tokenizer.eos_token_id + model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device) + input_len = model_inputs["input_ids"].shape[-1] + + # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :) + watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") + _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) + + args = { + "bias": 2.0, + "context_width": 1, + "seeding_scheme": "selfhash", + "greenlist_ratio": 0.25, + "hashing_key": 15485863, + } + output = model.generate(**model_inputs, do_sample=False, max_length=15) + output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) + + # check that the watermarked text is generating what is should + self.assertListEqual( + output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]] + ) + self.assertListEqual( + output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]] + ) + + detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) + detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) + detection_out = detector(output[:, input_len:], return_dict=True) + + # check that the detector is detecting watermarked text + self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) + self.assertListEqual(detection_out.prediction.tolist(), [False]) + @slow def test_beam_search_example_integration(self): # PT-only test: TF doesn't have a BeamSearchScorer