From 2bef3433e5fa86345331ac3d856a63f6931f70bd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 27 Aug 2021 17:38:34 +0200 Subject: [PATCH] [Flax] Correct all return tensors to numpy (#13307) * fix_torch_device_generate_test * remove @ * finish find and replace --- examples/research_projects/jax-projects/big_bird/evaluate.py | 2 +- src/transformers/file_utils.py | 2 +- src/transformers/generation_flax_utils.py | 2 +- src/transformers/models/bert/modeling_flax_bert.py | 2 +- src/transformers/models/big_bird/modeling_flax_big_bird.py | 2 +- src/transformers/models/electra/modeling_flax_electra.py | 2 +- src/transformers/models/vit/modeling_flax_vit.py | 4 ++-- tests/test_modeling_flax_bart.py | 2 +- tests/test_modeling_flax_gpt2.py | 2 +- tests/test_modeling_flax_gpt_neo.py | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/evaluate.py b/examples/research_projects/jax-projects/big_bird/evaluate.py index d81db40a95..de01e8fc81 100644 --- a/examples/research_projects/jax-projects/big_bird/evaluate.py +++ b/examples/research_projects/jax-projects/big_bird/evaluate.py @@ -110,7 +110,7 @@ def main(): inputs = tokenizer( example["question"], example["context"], - return_tensors="jax", + return_tensors="np", max_length=4096, padding="max_length", truncation=True, diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 0b31c17bd5..cd31371e03 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1121,7 +1121,7 @@ FLAX_CAUSAL_LM_SAMPLE = r""" >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") >>> outputs = model(**inputs) >>> # retrieve logts for next token diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index ec7e62b3fc..2b686a139b 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -231,7 +231,7 @@ class FlaxGenerationMixin: >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2") >>> input_context = "The dog" >>> # encode input context - >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids + >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids >>> # generate candidates using sampling >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 4db61eece2..75204debf0 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -757,7 +757,7 @@ FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') >>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased') - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") >>> outputs = model(**inputs) >>> prediction_logits = outputs.prediction_logits diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 46809667dd..f001498041 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -1567,7 +1567,7 @@ FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """ >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base') >>> model = FlaxBigBirdForPreTraining.from_pretrained('google/bigbird-roberta-base') - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") >>> outputs = model(**inputs) >>> prediction_logits = outputs.prediction_logits diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 43c38fcdd3..12c1afb897 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -761,7 +761,7 @@ FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """ >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator') >>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator') - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index 7b448da8b0..d2b23b5493 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -512,7 +512,7 @@ FLAX_VISION_MODEL_DOCSTRING = """ >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') >>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k') - >>> inputs = feature_extractor(images=image, return_tensors="jax") + >>> inputs = feature_extractor(images=image, return_tensors="np") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state """ @@ -592,7 +592,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """ >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') >>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224') - >>> inputs = feature_extractor(images=image, return_tensors="jax") + >>> inputs = feature_extractor(images=image, return_tensors="np") >>> outputs = model(**inputs) >>> logits = outputs.logits diff --git a/tests/test_modeling_flax_bart.py b/tests/test_modeling_flax_bart.py index ea19b4b6d9..d1a51e3612 100644 --- a/tests/test_modeling_flax_bart.py +++ b/tests/test_modeling_flax_bart.py @@ -453,7 +453,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT padding="max_length", truncation_strategy="only_first", truncation=True, - return_tensors="jax", + return_tensors="np", ) self.assertEqual(1024, dct["input_ids"].shape[1]) diff --git a/tests/test_modeling_flax_gpt2.py b/tests/test_modeling_flax_gpt2.py index b93c6e5985..0c793ebd27 100644 --- a/tests/test_modeling_flax_gpt2.py +++ b/tests/test_modeling_flax_gpt2.py @@ -213,7 +213,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes @slow def test_batch_generation(self): tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="", padding_side="left") - inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True) + inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True) model = FlaxGPT2LMHeadModel.from_pretrained("gpt2") model.do_sample = False diff --git a/tests/test_modeling_flax_gpt_neo.py b/tests/test_modeling_flax_gpt_neo.py index 93eccf0872..2916bec5b9 100644 --- a/tests/test_modeling_flax_gpt_neo.py +++ b/tests/test_modeling_flax_gpt_neo.py @@ -204,7 +204,7 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitt @slow def test_batch_generation(self): tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left") - inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True) + inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True) model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M") model.do_sample = False