From c53a6eae74778ad65de9edb0f0bd1eef3674bde3 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 25 Jul 2023 10:15:00 +0200 Subject: [PATCH] [`RWKV`] Add note in doc on `RwkvStoppingCriteria` (#25055) * Add note in doc on `RwkvStoppingCriteria` * give some breathing space to the code --- docs/source/en/model_doc/rwkv.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/en/model_doc/rwkv.md b/docs/source/en/model_doc/rwkv.md index cde8218bd0..9293db14cc 100644 --- a/docs/source/en/model_doc/rwkv.md +++ b/docs/source/en/model_doc/rwkv.md @@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5) ``` +If you want to make sure the model stops generating when `'\n\n'` is detected, we recommend using the following stopping criteria: + +```python +from transformers import StoppingCriteria + +class RwkvStoppingCriteria(StoppingCriteria): + def __init__(self, eos_sequence = [187,187], eos_token_id = 537): + self.eos_sequence = eos_sequence + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_2_ids = input_ids[:,-2:].tolist() + return self.eos_sequence in last_2_ids + + +output = model.generate(inputs["input_ids"], max_new_tokens=64, stopping_criteria = [RwkvStoppingCriteria()]) +``` + ## RwkvConfig [[autodoc]] RwkvConfig