Add Watermarking LogitsProcessor and WatermarkDetector (#29676)
* add watermarking processor * remove the other hashing (context width=1 always) * make style * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * update watermarking process * add detector * update tests to use detector * fix failing tests * rename `input_seq` * make style * doc for processor * minor fixes * docs * make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add PR suggestions * let's use lru_cache's default max size (128) * import processor if torch available * maybe like this * lets move the config to torch independet file * add docs * tiny docs fix to make the test happy * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/watermarking.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * PR suggestions * add docs * fix test * fix docs * address pr comments * style * Revert "style" This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2. 
* correct style * make doctest green --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
65ea1904ff
commit
5ad960f1f4
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||
|
||||
@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
def test_watermarking_processor(self):
|
||||
batch_size = 3
|
||||
vocab_size = 20
|
||||
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
# raise error if incorrect seeding_scheme is passed
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
|
||||
|
||||
# raise error if the greenlist_ratio is not in range (0.0, 1.0)
|
||||
with self.assertRaises(ValueError):
|
||||
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
|
||||
|
||||
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
|
||||
|
||||
# use fixed id for last token, needed for reproducibility and tests
|
||||
input_ids[:, -1] = 10
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||
|
||||
@@ -75,6 +75,8 @@ if is_torch_available():
|
||||
SampleEncoderDecoderOutput,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
@@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_watermark_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = model_inputs["input_ids"].shape[-1]
|
||||
|
||||
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
|
||||
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
||||
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
|
||||
|
||||
args = {
|
||||
"bias": 2.0,
|
||||
"context_width": 1,
|
||||
"seeding_scheme": "selfhash",
|
||||
"greenlist_ratio": 0.25,
|
||||
"hashing_key": 15485863,
|
||||
}
|
||||
output = model.generate(**model_inputs, do_sample=False, max_length=15)
|
||||
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
|
||||
|
||||
# check that the watermarked text is generating what it should
|
||||
self.assertListEqual(
|
||||
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
|
||||
)
|
||||
self.assertListEqual(
|
||||
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
|
||||
)
|
||||
|
||||
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
|
||||
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
|
||||
detection_out = detector(output[:, input_len:], return_dict=True)
|
||||
|
||||
# check that the detector is detecting watermarked text
|
||||
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|
||||
Reference in New Issue
Block a user