[RWKV] Add note in doc on RwkvStoppingCriteria (#25055)
* Add note in doc on `RwkvStoppingCriteria` * give some breathing space to the code
This commit is contained in:
@@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state
|
|||||||
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
|
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you want to make sure the model stops generating when `'\n\n'` is detected, we recommend using the following stopping criteria:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import StoppingCriteria
|
||||||
|
|
||||||
|
class RwkvStoppingCriteria(StoppingCriteria):
|
||||||
|
def __init__(self, eos_sequence = [187,187], eos_token_id = 537):
|
||||||
|
self.eos_sequence = eos_sequence
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
last_2_ids = input_ids[:,-2:].tolist()
|
||||||
|
return self.eos_sequence in last_2_ids
|
||||||
|
|
||||||
|
|
||||||
|
output = model.generate(inputs["input_ids"], max_new_tokens=64, stopping_criteria = [RwkvStoppingCriteria()])
|
||||||
|
```
|
||||||
|
|
||||||
## RwkvConfig
|
## RwkvConfig
|
||||||
|
|
||||||
[[autodoc]] RwkvConfig
|
[[autodoc]] RwkvConfig
|
||||||
|
|||||||
Reference in New Issue
Block a user