Add Watermarking LogitsProcessor and WatermarkDetector (#29676)
* add watermarking processor * remove the other hashing (context width=1 always) * make style * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * update watermarking process * add detector * update tests to use detector * fix failing tests * rename `input_seq` * make style * doc for processor * minor fixes * docs * make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add PR suggestions * let's use lru_cache's default max size (128) * import processor if torch available * maybe like this * lets move the config to torch independet file * add docs * tiny docs fix to make the test happy * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * PR suggestions * add docs * fix test * fix docs * address pr comments * style * Revert "style" This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2. * correct style * make doctest green --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
65ea1904ff
commit
5ad960f1f4
@@ -173,6 +173,55 @@ your screen, one word at a time:
|
||||
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
||||
```
|
||||
|
||||
|
||||
## Watermarking
|
||||
|
||||
The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".
|
||||
When generating the "green" will have a small 'bias' value added to their logits, thus having a higher chance to be generated.
|
||||
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
|
||||
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
|
||||
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
|
||||
the inner functioning of watermarking, it is recommended to refer to the paper.
|
||||
|
||||
The watermarking can be used with any generative model in `tranformers` and does not require an extra classification model
|
||||
to detect watermarked text. To trigger watermarking, pass in a [`WatermarkingConfig`] with needed arguments directly to the
|
||||
`.generate()` method or add it to the [`GenerationConfig`]. Watermarked text can be later detected with a [`WatermarkDetector`].
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The WatermarkDetector internally relies on the proportion of "green" tokens, and whether generated text follows the coloring pattern.
|
||||
That is why it is recommended to strip off the prompt text, if it is much longer than the generated text.
|
||||
This also can have an effect when one sequence in the batch is a lot longer causing other rows to be padded.
|
||||
Additionally, the detector **must** be initiated with identical watermark configuration arguments used when generating.
|
||||
|
||||
</Tip>
|
||||
|
||||
Let's generate some text with watermarking. In the below code snippet, we set the bias to 2.5 which is a value that
|
||||
will be added to "green" tokens' logits. After generating watermarked text, we can pass it directly to the `WatermarkDetector`
|
||||
to check if the text is machine-generated (outputs `True` for machine-generated and `False` otherwise).
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
>>> tok.pad_token_id = tok.eos_token_id
|
||||
>>> tok.padding_side = "left"
|
||||
|
||||
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
|
||||
>>> input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
|
||||
|
||||
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
|
||||
>>> detection_out = detector(out, return_dict=True)
|
||||
>>> detection_out.prediction
|
||||
array([True, True])
|
||||
```
|
||||
|
||||
|
||||
## Decoding strategies
|
||||
|
||||
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
|
||||
|
||||
@@ -209,6 +209,10 @@ generation.
|
||||
[[autodoc]] WhisperTimeStampLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] WatermarkLogitsProcessor
|
||||
- __call__
|
||||
|
||||
|
||||
### TensorFlow
|
||||
|
||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||
@@ -372,3 +376,10 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
|
||||
## Watermark Utils
|
||||
|
||||
[[autodoc]] WatermarkDetector
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -41,6 +41,8 @@ like token streaming.
|
||||
- validate
|
||||
- get_generation_mode
|
||||
|
||||
[[autodoc]] generation.WatermarkingConfig
|
||||
|
||||
## GenerationMixin
|
||||
|
||||
[[autodoc]] generation.GenerationMixin
|
||||
|
||||
@@ -117,7 +117,12 @@ _import_structure = {
|
||||
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
||||
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
||||
"file_utils": [],
|
||||
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
|
||||
"generation": [
|
||||
"GenerationConfig",
|
||||
"TextIteratorStreamer",
|
||||
"TextStreamer",
|
||||
"WatermarkingConfig",
|
||||
],
|
||||
"hf_argparser": ["HfArgumentParser"],
|
||||
"hyperparameter_search": [],
|
||||
"image_transforms": [],
|
||||
@@ -1232,6 +1237,8 @@ else:
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WatermarkDetector",
|
||||
"WatermarkLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
)
|
||||
@@ -4617,7 +4624,7 @@ if TYPE_CHECKING:
|
||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
|
||||
# Generation
|
||||
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
|
||||
from .hf_argparser import HfArgumentParser
|
||||
|
||||
# Integrations
|
||||
@@ -5797,6 +5804,8 @@ if TYPE_CHECKING:
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkDetector,
|
||||
WatermarkLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["GenerationConfig", "GenerationMode"],
|
||||
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
|
||||
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@ else:
|
||||
"TypicalLogitsWarper",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
"WatermarkLogitsProcessor",
|
||||
]
|
||||
_import_structure["stopping_criteria"] = [
|
||||
"MaxNewTokensCriteria",
|
||||
@@ -106,6 +107,10 @@ else:
|
||||
"GenerateDecoderOnlyOutput",
|
||||
"GenerateEncoderDecoderOutput",
|
||||
]
|
||||
_import_structure["watermarking"] = [
|
||||
"WatermarkDetector",
|
||||
"WatermarkDetectorOutput",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
@@ -174,7 +179,7 @@ else:
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_utils import GenerationConfig, GenerationMode
|
||||
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
|
||||
from .streamers import TextIteratorStreamer, TextStreamer
|
||||
|
||||
try:
|
||||
@@ -218,6 +223,7 @@ if TYPE_CHECKING:
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
@@ -247,6 +253,10 @@ if TYPE_CHECKING:
|
||||
SampleDecoderOnlyOutput,
|
||||
SampleEncoderDecoderOutput,
|
||||
)
|
||||
from .watermarking import (
|
||||
WatermarkDetector,
|
||||
WatermarkDetectorOutput,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
|
||||
@@ -18,6 +18,7 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from .. import __version__
|
||||
@@ -221,6 +222,23 @@ class GenerationConfig(PushToHubMixin):
|
||||
low_memory (`bool`, *optional*):
|
||||
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
||||
Used with beam search and contrastive search.
|
||||
watermarking_config (Union[`WatermarkingConfig`, `dict`], *optional*):
|
||||
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
|
||||
If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
|
||||
- greenlist_ratio (`float`):
|
||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
- bias (`float`):
|
||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||
- hashing_key (`int`):
|
||||
Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||
- seeding_scheme (`str`):
|
||||
Algorithm to use for watermarking. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
- context_width(`int`):
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||
|
||||
> Parameters that define the output variables of generate
|
||||
|
||||
@@ -333,6 +351,13 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||
self.low_memory = kwargs.pop("low_memory", None)
|
||||
watermarking_config = kwargs.pop("watermarking_config", None)
|
||||
if watermarking_config is None:
|
||||
self.watermarking_config = None
|
||||
elif isinstance(watermarking_config, WatermarkingConfig):
|
||||
self.watermarking_config = watermarking_config
|
||||
else:
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
|
||||
|
||||
# Parameters that define the output variables of `generate`
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
@@ -613,6 +638,12 @@ class GenerationConfig(PushToHubMixin):
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
if not isinstance(self.watermarking_config, WatermarkingConfig):
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 5. check common issue: passing `generate` arguments inside the generation config
|
||||
generate_arguments = (
|
||||
"logits_processor",
|
||||
@@ -1021,7 +1052,16 @@ class GenerationConfig(PushToHubMixin):
|
||||
else:
|
||||
return obj
|
||||
|
||||
def convert_dataclass_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
|
||||
elif is_dataclass(obj):
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj
|
||||
|
||||
config_dict = convert_keys_to_string(config_dict)
|
||||
config_dict = convert_dataclass_to_dict(config_dict)
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
@@ -1093,3 +1133,142 @@ class GenerationConfig(PushToHubMixin):
|
||||
# Remove all the attributes that were updated, without modifying the input dict
|
||||
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
||||
return unused_kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatermarkingConfig:
|
||||
"""
|
||||
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
||||
|
||||
Accepts the following keys:
|
||||
- greenlist_ratio (`float`):
|
||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
- bias (`float`):
|
||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||
- hashing_key (`int`):
|
||||
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||
- seeding_scheme (`str`):
|
||||
Algorithm to use for watermarking. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
- context_width(`int`):
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
greenlist_ratio: Optional[float] = 0.25,
|
||||
bias: Optional[float] = 2.0,
|
||||
hashing_key: Optional[int] = 15485863,
|
||||
seeding_scheme: Optional[str] = "lefthash",
|
||||
context_width: Optional[int] = 1,
|
||||
):
|
||||
self.greenlist_ratio = greenlist_ratio
|
||||
self.bias = bias
|
||||
self.hashing_key = hashing_key
|
||||
self.seeding_scheme = seeding_scheme
|
||||
self.context_width = context_width
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, **kwargs):
|
||||
"""
|
||||
Constructs a WatermarkingConfig instance from a dictionary of parameters.
|
||||
|
||||
Args:
|
||||
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
||||
**kwargs: Additional keyword arguments to override dictionary values.
|
||||
|
||||
Returns:
|
||||
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
|
||||
"""
|
||||
config = cls(**config_dict)
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
return config
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
config_dict = self.to_dict()
|
||||
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
writer.write(json_string)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def __iter__(self):
|
||||
for attr, value in copy.deepcopy(self.__dict__).items():
|
||||
yield attr, value
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Serializes this instance to a JSON formatted string.
|
||||
|
||||
Returns:
|
||||
str: JSON formatted string representing the configuration instance.
|
||||
"""
|
||||
return json.dumps(self.__dict__, indent=2) + "\n"
|
||||
|
||||
def update(self, **kwargs):
|
||||
"""
|
||||
Update the configuration attributes with new values.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments representing configuration attributes and their new values.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
def validate(self):
|
||||
watermark_missing_arg_msg = (
|
||||
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
||||
"but found {found_value}"
|
||||
)
|
||||
if self.seeding_scheme not in ["selfhash", "lefthash"]:
|
||||
raise ValueError(
|
||||
watermark_missing_arg_msg.format(
|
||||
key="seeding_scheme",
|
||||
correct_value="[`selfhash`, `lefthash`]",
|
||||
found_value=self.seeding_scheme,
|
||||
),
|
||||
)
|
||||
if not 0.0 <= self.greenlist_ratio <= 1.0:
|
||||
raise ValueError(
|
||||
watermark_missing_arg_msg.format(
|
||||
key="greenlist_ratio",
|
||||
correct_value="in range between 0.0 and 1.0",
|
||||
found_value=self.seeding_scheme,
|
||||
),
|
||||
)
|
||||
if not self.context_width >= 1:
|
||||
raise ValueError(
|
||||
watermark_missing_arg_msg.format(
|
||||
key="context_width",
|
||||
correct_value="a positive integer",
|
||||
found_value=self.context_width,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2321,3 +2321,144 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
||||
scores_processed = torch.where(do_early_stop, early_stop_scores, scores)
|
||||
|
||||
return scores_processed
|
||||
|
||||
|
||||
class WatermarkLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to
|
||||
randomized set of "green" tokens before generating the next token. "Green" tokens selection process depends on the
|
||||
`seeding_scheme` used. The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
|
||||
|
||||
The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details,
|
||||
|
||||
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`):
|
||||
The model tokenizer's vocab_size. Used to calculate "green" tokens ratio.
|
||||
device (`str`):
|
||||
The device where model is allocated.
|
||||
greenlist_ratio (`float`, optional, *optional*, defaults to 0.25):
|
||||
The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
bias (`float`, optional, *optional*, defaults to 2.0):
|
||||
The bias added to the selected "green" tokens' logits. Consider lowering the
|
||||
`bias` if the text generation quality degrades. Recommended values are in the
|
||||
range of [0.5, 2.0]. Defaults to 2.0.
|
||||
hashing_key (`int`, optional, *optional*, defaults to 15485863):
|
||||
Key used for hashing. If you deploy this watermark, we advise using another private key.
|
||||
Defaults to 15485863 (the millionth prime).
|
||||
seeding_scheme (`str`, optional, *optional*, defaults to `"lefthash"`):
|
||||
The seeding scheme used for selecting "green" tokens. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||
context_width (`int`, *optional*, defaults to 1):
|
||||
The number of previous tokens to use when setting the seed.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
>>> inputs = tokenizer(["Alice and Bob are"], return_tensors="pt")
|
||||
|
||||
>>> # normal generation
|
||||
>>> out = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
'Alice and Bob are both in the same room.\n\n"I\'m not sure if you\'re'
|
||||
|
||||
>>> # watermarked generation
|
||||
>>> watermarking_config = WatermarkingConfig(bias=2.5, context_width=2, seeding_scheme="selfhash")
|
||||
>>> out = model.generate(inputs["input_ids"], watermarking_config=watermarking_config, max_length=20, do_sample=False)
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
'Alice and Bob are both still alive and well and the story is pretty much a one-hour adventure'
|
||||
|
||||
>>> # to detect watermarked text use the WatermarkDetector class
|
||||
>>> from transformers import WatermarkDetector
|
||||
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config= watermarking_config)
|
||||
>>> detection_preds = detector(out)
|
||||
>>> detection_preds
|
||||
array([ True])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size,
|
||||
device,
|
||||
greenlist_ratio: float = 0.25,
|
||||
bias: float = 2.0,
|
||||
hashing_key: int = 15485863,
|
||||
seeding_scheme: str = "lefthash",
|
||||
context_width: int = 1,
|
||||
):
|
||||
if seeding_scheme not in ["selfhash", "lefthash"]:
|
||||
raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but found {seeding_scheme}")
|
||||
if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0:
|
||||
raise ValueError(
|
||||
f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}"
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.greenlist_size = int(self.vocab_size * greenlist_ratio)
|
||||
self.bias = bias
|
||||
self.seeding_scheme = seeding_scheme
|
||||
self.rng = torch.Generator(device=device)
|
||||
self.hash_key = hashing_key
|
||||
self.context_width = context_width
|
||||
|
||||
self.rng.manual_seed(hashing_key)
|
||||
self.table_size = 1_000_003
|
||||
self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device)
|
||||
|
||||
def set_seed(self, input_seq: torch.LongTensor):
|
||||
input_seq = input_seq[-self.context_width :]
|
||||
if self.seeding_scheme == "selfhash":
|
||||
a = self.fixed_table[input_seq % self.table_size] + 1
|
||||
b = self.fixed_table[input_seq[-1] % self.table_size] + 1
|
||||
seed = (self.hash_key * a * b).min().item()
|
||||
else:
|
||||
seed = self.hash_key * input_seq[-1].item()
|
||||
self.rng.manual_seed(seed % (2**64 - 1))
|
||||
|
||||
def _get_greenlist_ids(self, input_seq: torch.LongTensor) -> torch.LongTensor:
|
||||
self.set_seed(input_seq)
|
||||
vocab_permutation = torch.randperm(self.vocab_size, device=input_seq.device, generator=self.rng)
|
||||
greenlist_ids = vocab_permutation[: self.greenlist_size]
|
||||
return greenlist_ids
|
||||
|
||||
def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor:
|
||||
"""
|
||||
Generate greenlist based on current candidate next token. Reject and move on if necessary.
|
||||
Runs for a fixed number of steps only for efficiency, since the methods is not batched.
|
||||
"""
|
||||
final_greenlist = []
|
||||
_, greedy_predictions = scores.sort(dim=-1, descending=True)
|
||||
|
||||
# 40 is an arbitrary number chosen to save compute and not run for long (taken from orig repo)
|
||||
for i in range(40):
|
||||
greenlist_ids = self._get_greenlist_ids(torch.cat([input_seq, greedy_predictions[i, None]], dim=-1))
|
||||
if greedy_predictions[i] in greenlist_ids:
|
||||
final_greenlist.append(greedy_predictions[i])
|
||||
return torch.tensor(final_greenlist, device=input_seq.device)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if input_ids.shape[-1] < self.context_width:
|
||||
logger.warning(
|
||||
f"`input_ids` should have at least `{self.context_width}` tokens but has {input_ids.shape[-1]}. "
|
||||
"The seeding will be skipped for this generation step!"
|
||||
)
|
||||
return scores
|
||||
|
||||
scores_processed = scores.clone()
|
||||
for b_idx, input_seq in enumerate(input_ids):
|
||||
if self.seeding_scheme == "selfhash":
|
||||
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
|
||||
else:
|
||||
greenlist_ids = self._get_greenlist_ids(input_seq)
|
||||
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
|
||||
|
||||
return scores_processed
|
||||
|
||||
@@ -74,6 +74,7 @@ from .logits_process import (
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
EosTokenCriteria,
|
||||
@@ -763,6 +764,7 @@ class GenerationMixin:
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||
logits_processor: Optional[LogitsProcessorList],
|
||||
device: str = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
@@ -879,6 +881,18 @@ class GenerationMixin:
|
||||
FutureWarning,
|
||||
)
|
||||
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
|
||||
if generation_config.watermarking_config is not None:
|
||||
processors.append(
|
||||
WatermarkLogitsProcessor(
|
||||
vocab_size=self.config.vocab_size,
|
||||
device=device,
|
||||
greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
|
||||
bias=generation_config.watermarking_config.bias,
|
||||
hashing_key=generation_config.watermarking_config.hashing_key,
|
||||
seeding_scheme=generation_config.watermarking_config.seeding_scheme,
|
||||
context_width=generation_config.watermarking_config.context_width,
|
||||
)
|
||||
)
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if generation_config.renormalize_logits is True:
|
||||
@@ -1632,6 +1646,7 @@ class GenerationMixin:
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
device=inputs_tensor.device,
|
||||
model_kwargs=model_kwargs,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
|
||||
240
src/transformers/generation/watermarking.py
Normal file
240
src/transformers/generation/watermarking.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..utils import is_torch_available, logging
|
||||
from .configuration_utils import WatermarkingConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from .logits_process import WatermarkLogitsProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatermarkDetectorOutput:
|
||||
"""
|
||||
Outputs of a watermark detector.
|
||||
|
||||
Args:
|
||||
num_tokens_scored (np.array of shape (batch_size)):
|
||||
Array containing the number of tokens scored for each element in the batch.
|
||||
num_green_tokens (np.array of shape (batch_size)):
|
||||
Array containing the number of green tokens for each element in the batch.
|
||||
green_fraction (np.array of shape (batch_size)):
|
||||
Array containing the fraction of green tokens for each element in the batch.
|
||||
z_score (np.array of shape (batch_size)):
|
||||
Array containing the z-score for each element in the batch. Z-score here shows
|
||||
how many standard deviations away is the green token count in the input text
|
||||
from the expected green token count for machine-generated text.
|
||||
p_value (np.array of shape (batch_size)):
|
||||
Array containing the p-value for each batch obtained from z-scores.
|
||||
prediction (np.array of shape (batch_size)), *optional*:
|
||||
Array containing boolean predictions whether a text is machine-generated for each element in the batch.
|
||||
confidence (np.array of shape (batch_size)), *optional*:
|
||||
Array containing confidence scores of a text being machine-generated for each element in the batch.
|
||||
"""
|
||||
|
||||
num_tokens_scored: np.array = None
|
||||
num_green_tokens: np.array = None
|
||||
green_fraction: np.array = None
|
||||
z_score: np.array = None
|
||||
p_value: np.array = None
|
||||
prediction: Optional[np.array] = None
|
||||
confidence: Optional[np.array] = None
|
||||
|
||||
|
||||
class WatermarkDetector:
|
||||
|
||||
r"""
|
||||
Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
|
||||
given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
|
||||
the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
|
||||
The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
|
||||
|
||||
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
|
||||
|
||||
Args:
|
||||
model_config (`PretrainedConfig`):
|
||||
The model config that will be used to get model specific arguments used when generating.
|
||||
device (`str`):
|
||||
The device which was used during watermarked text generation.
|
||||
watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
|
||||
The exact same watermarking config and arguments used when generating text.
|
||||
ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
|
||||
Whether to count every unique ngram only once or not.
|
||||
max_cache_size (`int`, *optional*, defaults to 128):
|
||||
The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
|
||||
|
||||
>>> model_id = "openai-community/gpt2"
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
>>> tok = AutoTokenizer.from_pretrained(model_id)
|
||||
>>> tok.pad_token_id = tok.eos_token_id
|
||||
>>> tok.padding_side = "left"
|
||||
|
||||
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
|
||||
>>> input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
>>> # first generate text with watermark and without
|
||||
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||
>>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
|
||||
>>> out = model.generate(**inputs, do_sample=False, max_length=20)
|
||||
|
||||
>>> # now we can instantiate the detector and check the generated text
|
||||
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
|
||||
>>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
|
||||
>>> detection_out = detector(out, return_dict=True)
|
||||
>>> detection_out_watermarked.prediction
|
||||
array([ True, True])
|
||||
|
||||
>>> detection_out.prediction
|
||||
array([False, False])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: PretrainedConfig,
|
||||
device: str,
|
||||
watermarking_config: Union[WatermarkingConfig, Dict],
|
||||
ignore_repeated_ngrams: bool = False,
|
||||
max_cache_size: int = 128,
|
||||
):
|
||||
if isinstance(watermarking_config, WatermarkingConfig):
|
||||
watermarking_config = watermarking_config.to_dict()
|
||||
|
||||
self.bos_token_id = (
|
||||
model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
|
||||
)
|
||||
self.greenlist_ratio = watermarking_config["greenlist_ratio"]
|
||||
self.ignore_repeated_ngrams = ignore_repeated_ngrams
|
||||
self.processor = WatermarkLogitsProcessor(
|
||||
vocab_size=model_config.vocab_size, device=device, **watermarking_config
|
||||
)
|
||||
|
||||
# Expensive re-seeding and sampling is cached.
|
||||
self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
|
||||
|
||||
def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
|
||||
greenlist_ids = self.processor._get_greenlist_ids(prefix)
|
||||
return target in greenlist_ids
|
||||
|
||||
def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
selfhash = int(self.processor.seeding_scheme == "selfhash")
|
||||
n = self.processor.context_width + 1 - selfhash
|
||||
indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
|
||||
ngram_tensors = input_ids[:, indices]
|
||||
|
||||
num_tokens_scored_batch = np.zeros(batch_size)
|
||||
green_token_count_batch = np.zeros(batch_size)
|
||||
for batch_idx in range(ngram_tensors.shape[0]):
|
||||
frequencies_table = collections.Counter(ngram_tensors[batch_idx])
|
||||
ngram_to_watermark_lookup = {}
|
||||
for ngram_example in frequencies_table.keys():
|
||||
prefix = ngram_example if selfhash else ngram_example[:-1]
|
||||
target = ngram_example[-1]
|
||||
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
|
||||
|
||||
if self.ignore_repeated_ngrams:
|
||||
# counts a green/red hit once per unique ngram.
|
||||
# num total tokens scored becomes the number unique ngrams.
|
||||
num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
|
||||
green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
|
||||
else:
|
||||
num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
|
||||
green_token_count_batch[batch_idx] = sum(
|
||||
freq * outcome
|
||||
for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
|
||||
)
|
||||
return num_tokens_scored_batch, green_token_count_batch
|
||||
|
||||
def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
|
||||
expected_count = self.greenlist_ratio
|
||||
numer = green_token_count - expected_count * total_num_tokens
|
||||
denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
|
||||
z = numer / denom
|
||||
return z
|
||||
|
||||
def _compute_pval(self, x, loc=0, scale=1):
|
||||
z = (x - loc) / scale
|
||||
return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
z_threshold: float = 3.0,
|
||||
return_dict: bool = False,
|
||||
) -> Union[WatermarkDetectorOutput, np.array]:
|
||||
"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`):
|
||||
The watermark generated text. It is advised to remove the prompt, which can affect the detection.
|
||||
z_threshold (`Dict`, *optional*, defaults to `3.0`):
|
||||
Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
|
||||
sensitivity and vice versa for lower z threshold.
|
||||
return_dict (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions,
|
||||
ma
|
||||
Return:
|
||||
[`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`]
|
||||
if `return_dict=True` otherwise a `np.array`.
|
||||
|
||||
"""
|
||||
|
||||
# Let's assume that if one batch start with `bos`, all batched also do
|
||||
if input_ids[0, 0] == self.bos_token_id:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
if input_ids.shape[-1] - self.processor.context_width < 1:
|
||||
raise ValueError(
|
||||
f"Must have at least `1` token to score after the first "
|
||||
f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
|
||||
)
|
||||
|
||||
num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
|
||||
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
||||
prediction = z_score > z_threshold
|
||||
|
||||
if return_dict:
|
||||
p_value = self._compute_pval(z_score)
|
||||
confidence = 1 - p_value
|
||||
|
||||
return WatermarkDetectorOutput(
|
||||
num_tokens_scored=num_tokens_scored,
|
||||
num_green_tokens=green_token_count,
|
||||
green_fraction=green_token_count / num_tokens_scored,
|
||||
z_score=z_score,
|
||||
p_value=p_value,
|
||||
prediction=prediction,
|
||||
confidence=confidence,
|
||||
)
|
||||
return prediction
|
||||
@@ -422,6 +422,20 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WatermarkDetector(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WatermarkLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||
|
||||
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
def test_watermarking_processor(self):
|
||||
batch_size = 3
|
||||
vocab_size = 20
|
||||
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
# raise error if incorrect seeding_scheme is passed
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
|
||||
|
||||
# raise error if the greenlist_ratio in not in range (0.0, 1.0)
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
|
||||
|
||||
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
|
||||
|
||||
# use fixed id for last token, needed for reprodicibility and tests
|
||||
input_ids[:, -1] = 10
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||
|
||||
@@ -75,6 +75,8 @@ if is_torch_available():
|
||||
SampleEncoderDecoderOutput,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
@@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_watermark_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = model_inputs["input_ids"].shape[-1]
|
||||
|
||||
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
|
||||
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
|
||||
|
||||
args = {
|
||||
"bias": 2.0,
|
||||
"context_width": 1,
|
||||
"seeding_scheme": "selfhash",
|
||||
"greenlist_ratio": 0.25,
|
||||
"hashing_key": 15485863,
|
||||
}
|
||||
output = model.generate(**model_inputs, do_sample=False, max_length=15)
|
||||
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
|
||||
|
||||
# check that the watermarked text is generating what is should
|
||||
self.assertListEqual(
|
||||
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
|
||||
)
|
||||
self.assertListEqual(
|
||||
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
|
||||
)
|
||||
|
||||
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
|
||||
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
|
||||
detection_out = detector(output[:, input_len:], return_dict=True)
|
||||
|
||||
# check that the detector is detecting watermarked text
|
||||
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|
||||
Reference in New Issue
Block a user