Fix bigbird random attention (#21023)
* switch np.random.permutation to jax.random.permuation * remove comments * remove leftover comment * skip similarity tests * modify indices_prng_key usage, add deterministic behaviour * update style * remove unused import * remove copy statement since classes are not identical * remove numpy import * revert removing copied from statements * make style from copied * remove copied from statement * update copied from statement to include only np.ndarry * add deterministic args, unittestskip equivalence tests
This commit is contained in:
committed by
GitHub
parent
27b66bea01
commit
88399476c3
@@ -158,7 +158,7 @@ class FlaxModelTesterMixin:
|
||||
if "ForMultipleChoice" in model_class.__name__:
|
||||
inputs_dict = {
|
||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||
if isinstance(v, (jnp.ndarray, np.ndarray))
|
||||
if isinstance(v, (jnp.ndarray, np.ndarray)) and k != "indices_prng_key"
|
||||
else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
@@ -629,7 +629,6 @@ class FlaxModelTesterMixin:
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user