T5 compile compatibilty (#34089)
* this worked in normal generation, needs more tests * fix almost all tests in t5 * nit * longt5, umt5, mt5 * style * udop, pix2struct * more models * fix some tests * fix onnx tests * tracing tests fixed * compile enabled and tested for t5 models * fix small bug in slow tests * [run-slow] t5 * uncomment * style * update with new generation refactoring * nit * fix copies * this is the fix, had to change t5 to fix copies * update * [run-slow] t5 * [run-slow] t5 * update * add test for encoder only T5 * clean up after rebase * fix pop2piano * add comment * style * fix copies after rebase * fix copies missed this one
This commit is contained in:
committed by
GitHub
parent
5077bc034f
commit
73d65e637b
@@ -37,6 +37,7 @@ import transformers
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
@@ -5109,10 +5110,15 @@ class ModelTesterMixin:
|
||||
batch_size = 1
|
||||
n_iter = 3
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
if self.is_encoder_decoder:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
model.generation_config.max_new_tokens = 4
|
||||
|
||||
@@ -5184,10 +5190,15 @@ class ModelTesterMixin:
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
if self.is_encoder_decoder:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
cache_implementation = "static"
|
||||
if model.config.model_type == "gemma2":
|
||||
|
||||
Reference in New Issue
Block a user