[RWKV] Add note in doc on RwkvStoppingCriteria (#25055)

* Add note in doc on `RwkvStoppingCriteria`

* give some breathing space to the code
This commit is contained in:
Arthur
2023-07-25 10:15:00 +02:00
committed by GitHub
parent d2295708a6
commit c53a6eae74

View File

@@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
```
If you want to make sure the model stops generating when `'\n\n'` is detected, we recommend using the following stopping criteria:
```python
from transformers import StoppingCriteria
class RwkvStoppingCriteria(StoppingCriteria):
def __init__(self, eos_sequence = [187,187], eos_token_id = 537):
self.eos_sequence = eos_sequence
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_2_ids = input_ids[:,-2:].tolist()
return self.eos_sequence in last_2_ids
output = model.generate(inputs["input_ids"], max_new_tokens=64, stopping_criteria = [RwkvStoppingCriteria()])
```
## RwkvConfig
[[autodoc]] RwkvConfig