From 8814043c8c62034277b04e73a44e25231ab020ad Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 25 Oct 2024 11:46:46 +0100 Subject: [PATCH] SynthID: better example (#34372) * better example * Update src/transformers/generation/configuration_utils.py * Update src/transformers/generation/logits_process.py * nits --- docs/source/en/internal/generation_utils.md | 4 +--- src/transformers/generation/configuration_utils.py | 10 +++++----- src/transformers/generation/logits_process.py | 10 +++++----- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 946940cb01..eb25ddb632 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -428,13 +428,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens - __call__ [[autodoc]] BayesianDetectorConfig - - __call__ [[autodoc]] BayesianDetectorModel - - __call__ + - forward [[autodoc]] SynthIDTextWatermarkingConfig - - __call__ [[autodoc]] SynthIDTextWatermarkDetector - __call__ diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index c460a19885..3c204481b0 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1471,8 +1471,8 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig): ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig - >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it') - >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it') + >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left") + >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b') >>> # SynthID Text configuration >>> watermarking_config = SynthIDTextWatermarkingConfig( @@ -1481,11 +1481,11 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig): ... ) >>> # Generation with watermarking - >>> tokenized_prompts = tokenizer(["your prompts here"]) + >>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True) >>> output_sequences = model.generate( - ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, + ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10 ... ) - >>> watermarked_text = tokenizer.batch_decode(output_sequences) + >>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) ``` """ diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index fde95c7a85..9d244191da 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2565,8 +2565,8 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor): ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig - >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it') - >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it') + >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left") + >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b') >>> # SynthID Text configuration >>> watermarking_config = SynthIDTextWatermarkingConfig( @@ -2575,11 +2575,11 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor): ... ) >>> # Generation with watermarking - >>> tokenized_prompts = tokenizer(["your prompts here"]) + >>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True) >>> output_sequences = model.generate( - ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, + ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10 ... ) - >>> watermarked_text = tokenizer.batch_decode(output_sequences) + >>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) ``` """