[Flax] Align jax flax device name (#12987)
* [Flax] Align device name in docs * make style * fix import error
This commit is contained in:
committed by
GitHub
parent
07df5578d9
commit
da9754a3a0
@@ -44,7 +44,6 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import (
|
||||
@@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
|
||||
if "ForMultipleChoice" in model_class.__name__:
|
||||
inputs_dict = {
|
||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
|
||||
if isinstance(v, (jnp.ndarray, np.ndarray))
|
||||
else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user