[Flax] Correct all return tensors to numpy (#13307)

* fix_torch_device_generate_test

* remove @

* finish find and replace
This commit is contained in:
Patrick von Platen
2021-08-27 17:38:34 +02:00
committed by GitHub
parent 8aa67fc192
commit 2bef3433e5
10 changed files with 11 additions and 11 deletions

View File

@@ -110,7 +110,7 @@ def main():
inputs = tokenizer(
example["question"],
example["context"],
return_tensors="jax",
return_tensors="np",
max_length=4096,
padding="max_length",
truncation=True,