Doc styler examples (#14953)
* Fix bad examples * Add black formatting to style_doc * Use first nonempty line * Put it at the right place * Don't add spaces to empty lines * Better templates * Deal with triple quotes in docstrings * Result of style_doc * Enable mdx treatment and fix code examples in MDXs * Result of doc styler on doc source files * Last fixes * Break copy from
This commit is contained in:
@@ -69,18 +69,22 @@ All the [checkpoints](https://huggingface.co/models?search=pegasus) are fine-tun
|
||||
```python
|
||||
>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
||||
>>> import torch
|
||||
|
||||
>>> src_text = [
|
||||
... """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
>>> ]
|
||||
... ]
|
||||
|
||||
>>> model_name = 'google/pegasus-xsum'
|
||||
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
|
||||
>>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(device)
|
||||
>>> translated = model.generate(**batch)
|
||||
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
>>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
... model_name = "google/pegasus-xsum"
|
||||
... device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
... tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
... model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
|
||||
... batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)
|
||||
... translated = model.generate(**batch)
|
||||
... tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
... assert (
|
||||
... tgt_text[0]
|
||||
... == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
... )
|
||||
```
|
||||
|
||||
## PegasusConfig
|
||||
|
||||
Reference in New Issue
Block a user