Update T5gemma (#39210)
* bug fix: add vocab_size to t5gemmaconfig for pipeline. * Update checkpoint placeholder * minor change * minor change * minor change: update example. * fix: add vocab_size as an explict arg. * buf fix: remove vocab_size verification; instead, re-set encoder/decoder vocab size. Note, in t5gemma, vocab size of encoder/decoder shoud be always the same. * add `add_generation_prompt` for message preprocessing.
This commit is contained in:
@@ -14,7 +14,13 @@ specific language governing permissions and limitations under the License.
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# T5Gemma
|
||||
|
||||
@@ -24,6 +30,9 @@ T5Gemma has two groups of model sizes: 1) [Gemma 2](https://ai.google.dev/gemma/
|
||||
|
||||
The pretrained varaints are trained with two objectives: prefix language modeling with knowledge distillation (PrefixLM) and UL2, separately. We release both variants for each model size. The instruction-turned varaints was post-trained with supervised fine-tuning and reinforcement learning.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the T5Gemma models in the right sidebar for more examples of how to apply T5Gemma to different language tasks.
|
||||
|
||||
The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line.
|
||||
|
||||
<hfoptions id="usage">
|
||||
@@ -35,43 +44,52 @@ import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(
|
||||
task="text2text-generation",
|
||||
model="google/t5gemma-placeholder",
|
||||
"text2text-generation",
|
||||
model="google/t5gemma-2b-2b-prefixlm-it",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
device="cuda", # replace with "mps" to run on a Mac device
|
||||
)
|
||||
|
||||
pipe("Question: Why is the sky blue?\nAnswer:", max_new_tokens=50)
|
||||
messages = [
|
||||
{"role": "user", "content": "Tell me an unknown interesting biology fact about the brain."},
|
||||
]
|
||||
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
pipe(prompt, max_new_tokens=32)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
# pip install accelerate
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-placeholder")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google/t5gemma-placeholder",
|
||||
"google/t5gemma-2b-2b-prefixlm-it",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
input_text = "Question: Why is the sky blue?\nAnswer:"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
messages = [
|
||||
{"role": "user", "content": "Tell me an unknown interesting biology fact about the brain."},
|
||||
]
|
||||
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True).to("cuda")
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=32)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```
|
||||
echo -e "Question: Why is the sky blue? Answer:" | transformers run --task text2text-generation --model google/t5gemma-placeholder --device 0
|
||||
echo -e "Write me a poem about Machine Learning. Answer:" | transformers run --task text2text-generation --model google/t5gemma-2b-2b-prefixlm --device 0
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## T5GemmaConfig
|
||||
|
||||
|
||||
Reference in New Issue
Block a user