Patch fix - don't use safetensors for TF models (#30118)
* Patch fix - don't use safetensors for TF models * Skip test for TF for now * Update for another test
This commit is contained in:
@@ -111,7 +111,7 @@ class GenerationIntegrationTestsMixin:
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_safetensors=is_pt)
|
||||
input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
gpt2_model = gpt2_model.to(torch_device)
|
||||
@@ -582,7 +582,7 @@ class GenerationIntegrationTestsMixin:
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_safetensors=is_pt)
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
@@ -611,7 +611,7 @@ class GenerationIntegrationTestsMixin:
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_safetensors=is_pt)
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
@@ -638,7 +638,7 @@ class GenerationIntegrationTestsMixin:
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_safetensors=is_pt)
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
|
||||
@@ -194,7 +194,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="tf")
|
||||
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_safetensors=False)
|
||||
|
||||
eos_token_id = 638
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
|
||||
Reference in New Issue
Block a user