[RWKV] Add note in doc on RwkvStoppingCriteria (#25055)
* Add note in doc on `RwkvStoppingCriteria` * give some breathing space to the code
This commit is contained in:
@@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state
|
||||
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
|
||||
```
|
||||
|
||||
If you want to make sure the model stops generating when `'\n\n'` is detected, we recommend using the following stopping criteria:
|
||||
|
||||
```python
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
class RwkvStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, eos_sequence = [187,187], eos_token_id = 537):
|
||||
self.eos_sequence = eos_sequence
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
last_2_ids = input_ids[:,-2:].tolist()
|
||||
return self.eos_sequence in last_2_ids
|
||||
|
||||
|
||||
output = model.generate(inputs["input_ids"], max_new_tokens=64, stopping_criteria = [RwkvStoppingCriteria()])
|
||||
```
|
||||
|
||||
## RwkvConfig
|
||||
|
||||
[[autodoc]] RwkvConfig
|
||||
|
||||
Reference in New Issue
Block a user