[Flax] Correct all return tensors to numpy (#13307)
* fix_torch_device_generate_test * remove @ * finish find and replace
This commit is contained in:
committed by
GitHub
parent
8aa67fc192
commit
2bef3433e5
@@ -453,7 +453,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
|
||||
padding="max_length",
|
||||
truncation_strategy="only_first",
|
||||
truncation=True,
|
||||
return_tensors="jax",
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
self.assertEqual(1024, dct["input_ids"].shape[1])
|
||||
|
||||
@@ -213,7 +213,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||
|
||||
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.do_sample = False
|
||||
|
||||
@@ -204,7 +204,7 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitt
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||
|
||||
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
|
||||
model.do_sample = False
|
||||
|
||||
Reference in New Issue
Block a user