[Flax] Align jax flax device name (#12987)

* [Flax] Align device name in docs

* make style

* fix import error
This commit is contained in:
Patrick von Platen
2021-08-04 16:00:09 +02:00
committed by GitHub
parent 07df5578d9
commit da9754a3a0
9 changed files with 365 additions and 387 deletions

View File

@@ -44,7 +44,6 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import (
@@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
if "ForMultipleChoice" in model_class.__name__:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
if isinstance(v, (jnp.ndarray, np.ndarray))
else v
for k, v in inputs_dict.items()
}