[Flax] Correct all return tensors to numpy (#13307)
* fix_torch_device_generate_test * remove @ * finish find and replace
This commit is contained in:
committed by
GitHub
parent
8aa67fc192
commit
2bef3433e5
@@ -110,7 +110,7 @@ def main():
|
|||||||
inputs = tokenizer(
|
inputs = tokenizer(
|
||||||
example["question"],
|
example["question"],
|
||||||
example["context"],
|
example["context"],
|
||||||
return_tensors="jax",
|
return_tensors="np",
|
||||||
max_length=4096,
|
max_length=4096,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
|||||||
@@ -1121,7 +1121,7 @@ FLAX_CAUSAL_LM_SAMPLE = r"""
|
|||||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
>>> # retrieve logts for next token
|
>>> # retrieve logts for next token
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class FlaxGenerationMixin:
|
|||||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||||
>>> input_context = "The dog"
|
>>> input_context = "The dog"
|
||||||
>>> # encode input context
|
>>> # encode input context
|
||||||
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
|
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
|
||||||
>>> # generate candidates using sampling
|
>>> # generate candidates using sampling
|
||||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||||
|
|||||||
@@ -757,7 +757,7 @@ FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
|
|||||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
|
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
>>> prediction_logits = outputs.prediction_logits
|
>>> prediction_logits = outputs.prediction_logits
|
||||||
|
|||||||
@@ -1567,7 +1567,7 @@ FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """
|
|||||||
>>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
|
>>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
|
||||||
>>> model = FlaxBigBirdForPreTraining.from_pretrained('google/bigbird-roberta-base')
|
>>> model = FlaxBigBirdForPreTraining.from_pretrained('google/bigbird-roberta-base')
|
||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
>>> prediction_logits = outputs.prediction_logits
|
>>> prediction_logits = outputs.prediction_logits
|
||||||
|
|||||||
@@ -761,7 +761,7 @@ FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
|
|||||||
>>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
>>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||||
>>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
|
>>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
|
||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
>>> prediction_logits = outputs.logits
|
>>> prediction_logits = outputs.logits
|
||||||
|
|||||||
@@ -512,7 +512,7 @@ FLAX_VISION_MODEL_DOCSTRING = """
|
|||||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||||
>>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
>>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||||
|
|
||||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
>>> inputs = feature_extractor(images=image, return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
"""
|
"""
|
||||||
@@ -592,7 +592,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
|
|||||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
||||||
>>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
>>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
||||||
|
|
||||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
>>> inputs = feature_extractor(images=image, return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
|
|
||||||
|
|||||||
@@ -453,7 +453,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
|
|||||||
padding="max_length",
|
padding="max_length",
|
||||||
truncation_strategy="only_first",
|
truncation_strategy="only_first",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="jax",
|
return_tensors="np",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1024, dct["input_ids"].shape[1])
|
self.assertEqual(1024, dct["input_ids"].shape[1])
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
|||||||
@slow
|
@slow
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||||
|
|
||||||
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
model.do_sample = False
|
model.do_sample = False
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitt
|
|||||||
@slow
|
@slow
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
||||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||||
|
|
||||||
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
|
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
|
||||||
model.do_sample = False
|
model.do_sample = False
|
||||||
|
|||||||
Reference in New Issue
Block a user