Add Watermarking LogitsProcessor and WatermarkDetector (#29676)
* add watermarking processor * remove the other hashing (context width=1 always) * make style * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * update watermarking process * add detector * update tests to use detector * fix failing tests * rename `input_seq` * make style * doc for processor * minor fixes * docs * make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add PR suggestions * let's use lru_cache's default max size (128) * import processor if torch available * maybe like this * lets move the config to torch independet file * add docs * tiny docs fix to make the test happy * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * PR suggestions * add docs * fix test * fix docs * address pr comments * style * Revert "style" This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2. * correct style * make doctest green --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
65ea1904ff
commit
5ad960f1f4
@@ -173,6 +173,55 @@ your screen, one word at a time:
|
|||||||
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Watermarking
|
||||||
|
|
||||||
|
The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".
|
||||||
|
When generating the "green" will have a small 'bias' value added to their logits, thus having a higher chance to be generated.
|
||||||
|
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
|
||||||
|
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
|
||||||
|
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
|
||||||
|
the inner functioning of watermarking, it is recommended to refer to the paper.
|
||||||
|
|
||||||
|
The watermarking can be used with any generative model in `tranformers` and does not require an extra classification model
|
||||||
|
to detect watermarked text. To trigger watermarking, pass in a [`WatermarkingConfig`] with needed arguments directly to the
|
||||||
|
`.generate()` method or add it to the [`GenerationConfig`]. Watermarked text can be later detected with a [`WatermarkDetector`].
|
||||||
|
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
The WatermarkDetector internally relies on the proportion of "green" tokens, and whether generated text follows the coloring pattern.
|
||||||
|
That is why it is recommended to strip off the prompt text, if it is much longer than the generated text.
|
||||||
|
This also can have an effect when one sequence in the batch is a lot longer causing other rows to be padded.
|
||||||
|
Additionally, the detector **must** be initiated with identical watermark configuration arguments used when generating.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Let's generate some text with watermarking. In the below code snippet, we set the bias to 2.5 which is a value that
|
||||||
|
will be added to "green" tokens' logits. After generating watermarked text, we can pass it directly to the `WatermarkDetector`
|
||||||
|
to check if the text is machine-generated (outputs `True` for machine-generated and `False` otherwise).
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||||
|
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
>>> tok.pad_token_id = tok.eos_token_id
|
||||||
|
>>> tok.padding_side = "left"
|
||||||
|
|
||||||
|
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
|
||||||
|
>>> input_len = inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||||
|
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
|
||||||
|
|
||||||
|
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
|
||||||
|
>>> detection_out = detector(out, return_dict=True)
|
||||||
|
>>> detection_out.prediction
|
||||||
|
array([True, True])
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Decoding strategies
|
## Decoding strategies
|
||||||
|
|
||||||
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
|
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
|
||||||
|
|||||||
@@ -209,6 +209,10 @@ generation.
|
|||||||
[[autodoc]] WhisperTimeStampLogitsProcessor
|
[[autodoc]] WhisperTimeStampLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] WatermarkLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
|
||||||
### TensorFlow
|
### TensorFlow
|
||||||
|
|
||||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||||
@@ -372,3 +376,10 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
- update
|
- update
|
||||||
- get_seq_length
|
- get_seq_length
|
||||||
- reorder_cache
|
- reorder_cache
|
||||||
|
|
||||||
|
|
||||||
|
## Watermark Utils
|
||||||
|
|
||||||
|
[[autodoc]] WatermarkDetector
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ like token streaming.
|
|||||||
- validate
|
- validate
|
||||||
- get_generation_mode
|
- get_generation_mode
|
||||||
|
|
||||||
|
[[autodoc]] generation.WatermarkingConfig
|
||||||
|
|
||||||
## GenerationMixin
|
## GenerationMixin
|
||||||
|
|
||||||
[[autodoc]] generation.GenerationMixin
|
[[autodoc]] generation.GenerationMixin
|
||||||
|
|||||||
@@ -117,7 +117,12 @@ _import_structure = {
|
|||||||
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
||||||
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
||||||
"file_utils": [],
|
"file_utils": [],
|
||||||
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
|
"generation": [
|
||||||
|
"GenerationConfig",
|
||||||
|
"TextIteratorStreamer",
|
||||||
|
"TextStreamer",
|
||||||
|
"WatermarkingConfig",
|
||||||
|
],
|
||||||
"hf_argparser": ["HfArgumentParser"],
|
"hf_argparser": ["HfArgumentParser"],
|
||||||
"hyperparameter_search": [],
|
"hyperparameter_search": [],
|
||||||
"image_transforms": [],
|
"image_transforms": [],
|
||||||
@@ -1232,6 +1237,8 @@ else:
|
|||||||
"TopPLogitsWarper",
|
"TopPLogitsWarper",
|
||||||
"TypicalLogitsWarper",
|
"TypicalLogitsWarper",
|
||||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||||
|
"WatermarkDetector",
|
||||||
|
"WatermarkLogitsProcessor",
|
||||||
"WhisperTimeStampLogitsProcessor",
|
"WhisperTimeStampLogitsProcessor",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -4617,7 +4624,7 @@ if TYPE_CHECKING:
|
|||||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||||
|
|
||||||
# Generation
|
# Generation
|
||||||
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
|
||||||
from .hf_argparser import HfArgumentParser
|
from .hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
# Integrations
|
# Integrations
|
||||||
@@ -5797,6 +5804,8 @@ if TYPE_CHECKING:
|
|||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WatermarkDetector,
|
||||||
|
WatermarkLogitsProcessor,
|
||||||
WhisperTimeStampLogitsProcessor,
|
WhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_utils": ["GenerationConfig", "GenerationMode"],
|
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
|
||||||
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,6 +78,7 @@ else:
|
|||||||
"TypicalLogitsWarper",
|
"TypicalLogitsWarper",
|
||||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||||
"WhisperTimeStampLogitsProcessor",
|
"WhisperTimeStampLogitsProcessor",
|
||||||
|
"WatermarkLogitsProcessor",
|
||||||
]
|
]
|
||||||
_import_structure["stopping_criteria"] = [
|
_import_structure["stopping_criteria"] = [
|
||||||
"MaxNewTokensCriteria",
|
"MaxNewTokensCriteria",
|
||||||
@@ -106,6 +107,10 @@ else:
|
|||||||
"GenerateDecoderOnlyOutput",
|
"GenerateDecoderOnlyOutput",
|
||||||
"GenerateEncoderDecoderOutput",
|
"GenerateEncoderDecoderOutput",
|
||||||
]
|
]
|
||||||
|
_import_structure["watermarking"] = [
|
||||||
|
"WatermarkDetector",
|
||||||
|
"WatermarkDetectorOutput",
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
@@ -174,7 +179,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_utils import GenerationConfig, GenerationMode
|
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
|
||||||
from .streamers import TextIteratorStreamer, TextStreamer
|
from .streamers import TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -218,6 +223,7 @@ if TYPE_CHECKING:
|
|||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WatermarkLogitsProcessor,
|
||||||
WhisperTimeStampLogitsProcessor,
|
WhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
@@ -247,6 +253,10 @@ if TYPE_CHECKING:
|
|||||||
SampleDecoderOnlyOutput,
|
SampleDecoderOnlyOutput,
|
||||||
SampleEncoderDecoderOutput,
|
SampleEncoderDecoderOutput,
|
||||||
)
|
)
|
||||||
|
from .watermarking import (
|
||||||
|
WatermarkDetector,
|
||||||
|
WatermarkDetectorOutput,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass, is_dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
@@ -221,6 +222,23 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
low_memory (`bool`, *optional*):
|
low_memory (`bool`, *optional*):
|
||||||
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
||||||
Used with beam search and contrastive search.
|
Used with beam search and contrastive search.
|
||||||
|
watermarking_config (Union[`WatermarkingConfig`, `dict`], *optional*):
|
||||||
|
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
|
||||||
|
If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
||||||
|
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
|
||||||
|
- greenlist_ratio (`float`):
|
||||||
|
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||||
|
- bias (`float`):
|
||||||
|
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||||
|
- hashing_key (`int`):
|
||||||
|
Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||||
|
- seeding_scheme (`str`):
|
||||||
|
Algorithm to use for watermarking. Accepts values:
|
||||||
|
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||||
|
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||||
|
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||||
|
- context_width(`int`):
|
||||||
|
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||||
|
|
||||||
> Parameters that define the output variables of generate
|
> Parameters that define the output variables of generate
|
||||||
|
|
||||||
@@ -333,6 +351,13 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||||
self.low_memory = kwargs.pop("low_memory", None)
|
self.low_memory = kwargs.pop("low_memory", None)
|
||||||
|
watermarking_config = kwargs.pop("watermarking_config", None)
|
||||||
|
if watermarking_config is None:
|
||||||
|
self.watermarking_config = None
|
||||||
|
elif isinstance(watermarking_config, WatermarkingConfig):
|
||||||
|
self.watermarking_config = watermarking_config
|
||||||
|
else:
|
||||||
|
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
|
||||||
|
|
||||||
# Parameters that define the output variables of `generate`
|
# Parameters that define the output variables of `generate`
|
||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
@@ -613,6 +638,12 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
f"({self.num_beams})."
|
f"({self.num_beams})."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check watermarking arguments
|
||||||
|
if self.watermarking_config is not None:
|
||||||
|
if not isinstance(self.watermarking_config, WatermarkingConfig):
|
||||||
|
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||||
|
self.watermarking_config.validate()
|
||||||
|
|
||||||
# 5. check common issue: passing `generate` arguments inside the generation config
|
# 5. check common issue: passing `generate` arguments inside the generation config
|
||||||
generate_arguments = (
|
generate_arguments = (
|
||||||
"logits_processor",
|
"logits_processor",
|
||||||
@@ -1021,7 +1052,16 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def convert_dataclass_to_dict(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
|
||||||
|
elif is_dataclass(obj):
|
||||||
|
return obj.to_dict()
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
config_dict = convert_keys_to_string(config_dict)
|
config_dict = convert_keys_to_string(config_dict)
|
||||||
|
config_dict = convert_dataclass_to_dict(config_dict)
|
||||||
|
|
||||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
@@ -1093,3 +1133,142 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# Remove all the attributes that were updated, without modifying the input dict
|
# Remove all the attributes that were updated, without modifying the input dict
|
||||||
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
||||||
return unused_kwargs
|
return unused_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WatermarkingConfig:
|
||||||
|
"""
|
||||||
|
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
||||||
|
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
||||||
|
|
||||||
|
Accepts the following keys:
|
||||||
|
- greenlist_ratio (`float`):
|
||||||
|
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||||
|
- bias (`float`):
|
||||||
|
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||||
|
- hashing_key (`int`):
|
||||||
|
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||||
|
- seeding_scheme (`str`):
|
||||||
|
Algorithm to use for watermarking. Accepts values:
|
||||||
|
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||||
|
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||||
|
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||||
|
- context_width(`int`):
|
||||||
|
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
greenlist_ratio: Optional[float] = 0.25,
|
||||||
|
bias: Optional[float] = 2.0,
|
||||||
|
hashing_key: Optional[int] = 15485863,
|
||||||
|
seeding_scheme: Optional[str] = "lefthash",
|
||||||
|
context_width: Optional[int] = 1,
|
||||||
|
):
|
||||||
|
self.greenlist_ratio = greenlist_ratio
|
||||||
|
self.bias = bias
|
||||||
|
self.hashing_key = hashing_key
|
||||||
|
self.seeding_scheme = seeding_scheme
|
||||||
|
self.context_width = context_width
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config_dict, **kwargs):
|
||||||
|
"""
|
||||||
|
Constructs a WatermarkingConfig instance from a dictionary of parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
||||||
|
**kwargs: Additional keyword arguments to override dictionary values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
|
||||||
|
"""
|
||||||
|
config = cls(**config_dict)
|
||||||
|
to_remove = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(config, key):
|
||||||
|
setattr(config, key, value)
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
kwargs.pop(key, None)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||||
|
"""
|
||||||
|
Save this instance to a JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||||
|
"""
|
||||||
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
writer.write(json_string)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
||||||
|
"""
|
||||||
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for attr, value in copy.deepcopy(self.__dict__).items():
|
||||||
|
yield attr, value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||||
|
|
||||||
|
def to_json_string(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance to a JSON formatted string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted string representing the configuration instance.
|
||||||
|
"""
|
||||||
|
return json.dumps(self.__dict__, indent=2) + "\n"
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Update the configuration attributes with new values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments representing configuration attributes and their new values.
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
watermark_missing_arg_msg = (
|
||||||
|
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
||||||
|
"but found {found_value}"
|
||||||
|
)
|
||||||
|
if self.seeding_scheme not in ["selfhash", "lefthash"]:
|
||||||
|
raise ValueError(
|
||||||
|
watermark_missing_arg_msg.format(
|
||||||
|
key="seeding_scheme",
|
||||||
|
correct_value="[`selfhash`, `lefthash`]",
|
||||||
|
found_value=self.seeding_scheme,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not 0.0 <= self.greenlist_ratio <= 1.0:
|
||||||
|
raise ValueError(
|
||||||
|
watermark_missing_arg_msg.format(
|
||||||
|
key="greenlist_ratio",
|
||||||
|
correct_value="in range between 0.0 and 1.0",
|
||||||
|
found_value=self.seeding_scheme,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not self.context_width >= 1:
|
||||||
|
raise ValueError(
|
||||||
|
watermark_missing_arg_msg.format(
|
||||||
|
key="context_width",
|
||||||
|
correct_value="a positive integer",
|
||||||
|
found_value=self.context_width,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@@ -2321,3 +2321,144 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
|||||||
scores_processed = torch.where(do_early_stop, early_stop_scores, scores)
|
scores_processed = torch.where(do_early_stop, early_stop_scores, scores)
|
||||||
|
|
||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to
|
||||||
|
randomized set of "green" tokens before generating the next token. "Green" tokens selection process depends on the
|
||||||
|
`seeding_scheme` used. The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
|
||||||
|
|
||||||
|
The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details,
|
||||||
|
|
||||||
|
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`):
|
||||||
|
The model tokenizer's vocab_size. Used to calculate "green" tokens ratio.
|
||||||
|
device (`str`):
|
||||||
|
The device where model is allocated.
|
||||||
|
greenlist_ratio (`float`, optional, *optional*, defaults to 0.25):
|
||||||
|
The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||||
|
bias (`float`, optional, *optional*, defaults to 2.0):
|
||||||
|
The bias added to the selected "green" tokens' logits. Consider lowering the
|
||||||
|
`bias` if the text generation quality degrades. Recommended values are in the
|
||||||
|
range of [0.5, 2.0]. Defaults to 2.0.
|
||||||
|
hashing_key (`int`, optional, *optional*, defaults to 15485863):
|
||||||
|
Key used for hashing. If you deploy this watermark, we advise using another private key.
|
||||||
|
Defaults to 15485863 (the millionth prime).
|
||||||
|
seeding_scheme (`str`, optional, *optional*, defaults to `"lefthash"`):
|
||||||
|
The seeding scheme used for selecting "green" tokens. Accepts values:
|
||||||
|
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from paper)
|
||||||
|
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from paper)
|
||||||
|
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||||
|
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||||
|
context_width (`int`, *optional*, defaults to 1):
|
||||||
|
The number of previous tokens to use when setting the seed.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
>>> inputs = tokenizer(["Alice and Bob are"], return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # normal generation
|
||||||
|
>>> out = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
|
||||||
|
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||||
|
'Alice and Bob are both in the same room.\n\n"I\'m not sure if you\'re'
|
||||||
|
|
||||||
|
>>> # watermarked generation
|
||||||
|
>>> watermarking_config = WatermarkingConfig(bias=2.5, context_width=2, seeding_scheme="selfhash")
|
||||||
|
>>> out = model.generate(inputs["input_ids"], watermarking_config=watermarking_config, max_length=20, do_sample=False)
|
||||||
|
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||||
|
'Alice and Bob are both still alive and well and the story is pretty much a one-hour adventure'
|
||||||
|
|
||||||
|
>>> # to detect watermarked text use the WatermarkDetector class
|
||||||
|
>>> from transformers import WatermarkDetector
|
||||||
|
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config= watermarking_config)
|
||||||
|
>>> detection_preds = detector(out)
|
||||||
|
>>> detection_preds
|
||||||
|
array([ True])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size,
|
||||||
|
device,
|
||||||
|
greenlist_ratio: float = 0.25,
|
||||||
|
bias: float = 2.0,
|
||||||
|
hashing_key: int = 15485863,
|
||||||
|
seeding_scheme: str = "lefthash",
|
||||||
|
context_width: int = 1,
|
||||||
|
):
|
||||||
|
if seeding_scheme not in ["selfhash", "lefthash"]:
|
||||||
|
raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but found {seeding_scheme}")
|
||||||
|
if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0:
|
||||||
|
raise ValueError(
|
||||||
|
f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.greenlist_size = int(self.vocab_size * greenlist_ratio)
|
||||||
|
self.bias = bias
|
||||||
|
self.seeding_scheme = seeding_scheme
|
||||||
|
self.rng = torch.Generator(device=device)
|
||||||
|
self.hash_key = hashing_key
|
||||||
|
self.context_width = context_width
|
||||||
|
|
||||||
|
self.rng.manual_seed(hashing_key)
|
||||||
|
self.table_size = 1_000_003
|
||||||
|
self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device)
|
||||||
|
|
||||||
|
def set_seed(self, input_seq: torch.LongTensor):
|
||||||
|
input_seq = input_seq[-self.context_width :]
|
||||||
|
if self.seeding_scheme == "selfhash":
|
||||||
|
a = self.fixed_table[input_seq % self.table_size] + 1
|
||||||
|
b = self.fixed_table[input_seq[-1] % self.table_size] + 1
|
||||||
|
seed = (self.hash_key * a * b).min().item()
|
||||||
|
else:
|
||||||
|
seed = self.hash_key * input_seq[-1].item()
|
||||||
|
self.rng.manual_seed(seed % (2**64 - 1))
|
||||||
|
|
||||||
|
def _get_greenlist_ids(self, input_seq: torch.LongTensor) -> torch.LongTensor:
|
||||||
|
self.set_seed(input_seq)
|
||||||
|
vocab_permutation = torch.randperm(self.vocab_size, device=input_seq.device, generator=self.rng)
|
||||||
|
greenlist_ids = vocab_permutation[: self.greenlist_size]
|
||||||
|
return greenlist_ids
|
||||||
|
|
||||||
|
def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Generate greenlist based on current candidate next token. Reject and move on if necessary.
|
||||||
|
Runs for a fixed number of steps only for efficiency, since the methods is not batched.
|
||||||
|
"""
|
||||||
|
final_greenlist = []
|
||||||
|
_, greedy_predictions = scores.sort(dim=-1, descending=True)
|
||||||
|
|
||||||
|
# 40 is an arbitrary number chosen to save compute and not run for long (taken from orig repo)
|
||||||
|
for i in range(40):
|
||||||
|
greenlist_ids = self._get_greenlist_ids(torch.cat([input_seq, greedy_predictions[i, None]], dim=-1))
|
||||||
|
if greedy_predictions[i] in greenlist_ids:
|
||||||
|
final_greenlist.append(greedy_predictions[i])
|
||||||
|
return torch.tensor(final_greenlist, device=input_seq.device)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if input_ids.shape[-1] < self.context_width:
|
||||||
|
logger.warning(
|
||||||
|
f"`input_ids` should have at least `{self.context_width}` tokens but has {input_ids.shape[-1]}. "
|
||||||
|
"The seeding will be skipped for this generation step!"
|
||||||
|
)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
scores_processed = scores.clone()
|
||||||
|
for b_idx, input_seq in enumerate(input_ids):
|
||||||
|
if self.seeding_scheme == "selfhash":
|
||||||
|
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
|
||||||
|
else:
|
||||||
|
greenlist_ids = self._get_greenlist_ids(input_seq)
|
||||||
|
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
|
||||||
|
|
||||||
|
return scores_processed
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ from .logits_process import (
|
|||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WatermarkLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
EosTokenCriteria,
|
EosTokenCriteria,
|
||||||
@@ -763,6 +764,7 @@ class GenerationMixin:
|
|||||||
encoder_input_ids: torch.LongTensor,
|
encoder_input_ids: torch.LongTensor,
|
||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||||
logits_processor: Optional[LogitsProcessorList],
|
logits_processor: Optional[LogitsProcessorList],
|
||||||
|
device: str = None,
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -879,6 +881,18 @@ class GenerationMixin:
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
|
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
|
||||||
|
if generation_config.watermarking_config is not None:
|
||||||
|
processors.append(
|
||||||
|
WatermarkLogitsProcessor(
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
device=device,
|
||||||
|
greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
|
||||||
|
bias=generation_config.watermarking_config.bias,
|
||||||
|
hashing_key=generation_config.watermarking_config.hashing_key,
|
||||||
|
seeding_scheme=generation_config.watermarking_config.seeding_scheme,
|
||||||
|
context_width=generation_config.watermarking_config.context_width,
|
||||||
|
)
|
||||||
|
)
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
# `LogitNormalization` should always be the last logit processor, when present
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
if generation_config.renormalize_logits is True:
|
if generation_config.renormalize_logits is True:
|
||||||
@@ -1632,6 +1646,7 @@ class GenerationMixin:
|
|||||||
encoder_input_ids=inputs_tensor,
|
encoder_input_ids=inputs_tensor,
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
device=inputs_tensor.device,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
negative_prompt_ids=negative_prompt_ids,
|
negative_prompt_ids=negative_prompt_ids,
|
||||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
|
|||||||
240
src/transformers/generation/watermarking.py
Normal file
240
src/transformers/generation/watermarking.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..configuration_utils import PretrainedConfig
|
||||||
|
from ..utils import is_torch_available, logging
|
||||||
|
from .configuration_utils import WatermarkingConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .logits_process import WatermarkLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WatermarkDetectorOutput:
|
||||||
|
"""
|
||||||
|
Outputs of a watermark detector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_tokens_scored (np.array of shape (batch_size)):
|
||||||
|
Array containing the number of tokens scored for each element in the batch.
|
||||||
|
num_green_tokens (np.array of shape (batch_size)):
|
||||||
|
Array containing the number of green tokens for each element in the batch.
|
||||||
|
green_fraction (np.array of shape (batch_size)):
|
||||||
|
Array containing the fraction of green tokens for each element in the batch.
|
||||||
|
z_score (np.array of shape (batch_size)):
|
||||||
|
Array containing the z-score for each element in the batch. Z-score here shows
|
||||||
|
how many standard deviations away is the green token count in the input text
|
||||||
|
from the expected green token count for machine-generated text.
|
||||||
|
p_value (np.array of shape (batch_size)):
|
||||||
|
Array containing the p-value for each batch obtained from z-scores.
|
||||||
|
prediction (np.array of shape (batch_size)), *optional*:
|
||||||
|
Array containing boolean predictions whether a text is machine-generated for each element in the batch.
|
||||||
|
confidence (np.array of shape (batch_size)), *optional*:
|
||||||
|
Array containing confidence scores of a text being machine-generated for each element in the batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_tokens_scored: np.array = None
|
||||||
|
num_green_tokens: np.array = None
|
||||||
|
green_fraction: np.array = None
|
||||||
|
z_score: np.array = None
|
||||||
|
p_value: np.array = None
|
||||||
|
prediction: Optional[np.array] = None
|
||||||
|
confidence: Optional[np.array] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkDetector:
|
||||||
|
|
||||||
|
r"""
|
||||||
|
Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
|
||||||
|
given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
|
||||||
|
the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
|
||||||
|
The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
|
||||||
|
|
||||||
|
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config (`PretrainedConfig`):
|
||||||
|
The model config that will be used to get model specific arguments used when generating.
|
||||||
|
device (`str`):
|
||||||
|
The device which was used during watermarked text generation.
|
||||||
|
watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
|
||||||
|
The exact same watermarking config and arguments used when generating text.
|
||||||
|
ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to count every unique ngram only once or not.
|
||||||
|
max_cache_size (`int`, *optional*, defaults to 128):
|
||||||
|
The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
|
||||||
|
|
||||||
|
>>> model_id = "openai-community/gpt2"
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
>>> tok = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
>>> tok.pad_token_id = tok.eos_token_id
|
||||||
|
>>> tok.padding_side = "left"
|
||||||
|
|
||||||
|
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
|
||||||
|
>>> input_len = inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
>>> # first generate text with watermark and without
|
||||||
|
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||||
|
>>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
|
||||||
|
>>> out = model.generate(**inputs, do_sample=False, max_length=20)
|
||||||
|
|
||||||
|
>>> # now we can instantiate the detector and check the generated text
|
||||||
|
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
|
||||||
|
>>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
|
||||||
|
>>> detection_out = detector(out, return_dict=True)
|
||||||
|
>>> detection_out_watermarked.prediction
|
||||||
|
array([ True, True])
|
||||||
|
|
||||||
|
>>> detection_out.prediction
|
||||||
|
array([False, False])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: PretrainedConfig,
|
||||||
|
device: str,
|
||||||
|
watermarking_config: Union[WatermarkingConfig, Dict],
|
||||||
|
ignore_repeated_ngrams: bool = False,
|
||||||
|
max_cache_size: int = 128,
|
||||||
|
):
|
||||||
|
if isinstance(watermarking_config, WatermarkingConfig):
|
||||||
|
watermarking_config = watermarking_config.to_dict()
|
||||||
|
|
||||||
|
self.bos_token_id = (
|
||||||
|
model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
self.greenlist_ratio = watermarking_config["greenlist_ratio"]
|
||||||
|
self.ignore_repeated_ngrams = ignore_repeated_ngrams
|
||||||
|
self.processor = WatermarkLogitsProcessor(
|
||||||
|
vocab_size=model_config.vocab_size, device=device, **watermarking_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expensive re-seeding and sampling is cached.
|
||||||
|
self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
|
||||||
|
|
||||||
|
def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
|
||||||
|
greenlist_ids = self.processor._get_greenlist_ids(prefix)
|
||||||
|
return target in greenlist_ids
|
||||||
|
|
||||||
|
def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
selfhash = int(self.processor.seeding_scheme == "selfhash")
|
||||||
|
n = self.processor.context_width + 1 - selfhash
|
||||||
|
indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
|
||||||
|
ngram_tensors = input_ids[:, indices]
|
||||||
|
|
||||||
|
num_tokens_scored_batch = np.zeros(batch_size)
|
||||||
|
green_token_count_batch = np.zeros(batch_size)
|
||||||
|
for batch_idx in range(ngram_tensors.shape[0]):
|
||||||
|
frequencies_table = collections.Counter(ngram_tensors[batch_idx])
|
||||||
|
ngram_to_watermark_lookup = {}
|
||||||
|
for ngram_example in frequencies_table.keys():
|
||||||
|
prefix = ngram_example if selfhash else ngram_example[:-1]
|
||||||
|
target = ngram_example[-1]
|
||||||
|
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
|
||||||
|
|
||||||
|
if self.ignore_repeated_ngrams:
|
||||||
|
# counts a green/red hit once per unique ngram.
|
||||||
|
# num total tokens scored becomes the number unique ngrams.
|
||||||
|
num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
|
||||||
|
green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
|
||||||
|
else:
|
||||||
|
num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
|
||||||
|
green_token_count_batch[batch_idx] = sum(
|
||||||
|
freq * outcome
|
||||||
|
for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
|
||||||
|
)
|
||||||
|
return num_tokens_scored_batch, green_token_count_batch
|
||||||
|
|
||||||
|
def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
|
||||||
|
expected_count = self.greenlist_ratio
|
||||||
|
numer = green_token_count - expected_count * total_num_tokens
|
||||||
|
denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
|
||||||
|
z = numer / denom
|
||||||
|
return z
|
||||||
|
|
||||||
|
def _compute_pval(self, x, loc=0, scale=1):
|
||||||
|
z = (x - loc) / scale
|
||||||
|
return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
z_threshold: float = 3.0,
|
||||||
|
return_dict: bool = False,
|
||||||
|
) -> Union[WatermarkDetectorOutput, np.array]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor`):
|
||||||
|
The watermark generated text. It is advised to remove the prompt, which can affect the detection.
|
||||||
|
z_threshold (`Dict`, *optional*, defaults to `3.0`):
|
||||||
|
Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
|
||||||
|
sensitivity and vice versa for lower z threshold.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions,
|
||||||
|
ma
|
||||||
|
Return:
|
||||||
|
[`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`]
|
||||||
|
if `return_dict=True` otherwise a `np.array`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Let's assume that if one batch start with `bos`, all batched also do
|
||||||
|
if input_ids[0, 0] == self.bos_token_id:
|
||||||
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
|
if input_ids.shape[-1] - self.processor.context_width < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Must have at least `1` token to score after the first "
|
||||||
|
f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
|
||||||
|
)
|
||||||
|
|
||||||
|
num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
|
||||||
|
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
||||||
|
prediction = z_score > z_threshold
|
||||||
|
|
||||||
|
if return_dict:
|
||||||
|
p_value = self._compute_pval(z_score)
|
||||||
|
confidence = 1 - p_value
|
||||||
|
|
||||||
|
return WatermarkDetectorOutput(
|
||||||
|
num_tokens_scored=num_tokens_scored,
|
||||||
|
num_green_tokens=green_token_count,
|
||||||
|
green_fraction=green_token_count / num_tokens_scored,
|
||||||
|
z_score=z_score,
|
||||||
|
p_value=p_value,
|
||||||
|
prediction=prediction,
|
||||||
|
confidence=confidence,
|
||||||
|
)
|
||||||
|
return prediction
|
||||||
@@ -422,6 +422,20 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkDetector(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
|||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WatermarkLogitsProcessor,
|
||||||
)
|
)
|
||||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||||
|
|
||||||
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||||
]
|
]
|
||||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||||
|
|
||||||
|
def test_watermarking_processor(self):
|
||||||
|
batch_size = 3
|
||||||
|
vocab_size = 20
|
||||||
|
|
||||||
|
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
|
||||||
|
# raise error if incorrect seeding_scheme is passed
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
|
||||||
|
|
||||||
|
# raise error if the greenlist_ratio in not in range (0.0, 1.0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
|
||||||
|
|
||||||
|
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
|
||||||
|
|
||||||
|
# use fixed id for last token, needed for reprodicibility and tests
|
||||||
|
input_ids[:, -1] = 10
|
||||||
|
scores_wo_bias = scores[:, -1].clone()
|
||||||
|
out = watermark(input_ids=input_ids, scores=scores)
|
||||||
|
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||||
|
|||||||
@@ -75,6 +75,8 @@ if is_torch_available():
|
|||||||
SampleEncoderDecoderOutput,
|
SampleEncoderDecoderOutput,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
WatermarkDetector,
|
||||||
|
WatermarkingConfig,
|
||||||
)
|
)
|
||||||
from transformers.generation.utils import _speculative_sampling
|
from transformers.generation.utils import _speculative_sampling
|
||||||
|
|
||||||
@@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_watermark_generation(self):
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||||
|
input_len = model_inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
|
||||||
|
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||||
|
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"bias": 2.0,
|
||||||
|
"context_width": 1,
|
||||||
|
"seeding_scheme": "selfhash",
|
||||||
|
"greenlist_ratio": 0.25,
|
||||||
|
"hashing_key": 15485863,
|
||||||
|
}
|
||||||
|
output = model.generate(**model_inputs, do_sample=False, max_length=15)
|
||||||
|
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
|
||||||
|
|
||||||
|
# check that the watermarked text is generating what is should
|
||||||
|
self.assertListEqual(
|
||||||
|
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
|
||||||
|
)
|
||||||
|
self.assertListEqual(
|
||||||
|
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
|
||||||
|
)
|
||||||
|
|
||||||
|
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
|
||||||
|
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
|
||||||
|
detection_out = detector(output[:, input_len:], return_dict=True)
|
||||||
|
|
||||||
|
# check that the detector is detecting watermarked text
|
||||||
|
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||||
|
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_beam_search_example_integration(self):
|
def test_beam_search_example_integration(self):
|
||||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||||
|
|||||||
Reference in New Issue
Block a user