T5 compile compatibility (#34089)

* this worked in normal generation, needs more tests

* fix almost all tests in t5

* nit

* longt5, umt5, mt5

* style

* udop, pix2struct

* more models

* fix some tests

* fix onnx tests

* tracing tests fixed

* compile enabled and tested for t5 models

* fix small bug in slow tests

* [run-slow] t5

* uncomment

* style

* update with new generation refactoring

* nit

* fix copies

* this is the fix, had to change t5 to fix copies

* update

* [run-slow] t5

* [run-slow] t5

* update

* add test for encoder only T5

* clean up after rebase

* fix pop2piano

* add comment

* style

* fix copies after rebase

* fix copies, missed this one
This commit is contained in:
Raushan Turganbay
2024-10-22 08:23:53 +02:00
committed by GitHub
parent 5077bc034f
commit 73d65e637b
22 changed files with 2744 additions and 1179 deletions

View File

@@ -37,6 +37,7 @@ import transformers
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
GenerationConfig,
@@ -5109,10 +5110,15 @@ class ModelTesterMixin:
batch_size = 1
n_iter = 3
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
model.generation_config.max_new_tokens = 4
@@ -5184,10 +5190,15 @@ class ModelTesterMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
cache_implementation = "static"
if model.config.model_type == "gemma2":