Fix bigbird random attention (#21023)

* switch np.random.permutation to jax.random.permuation

* remove comments

* remove leftover comment

* skip similarity tests

* modify indices_prng_key usage, add deterministic behaviour

* update style

* remove unused import

* remove copy statement since classes are not identical

* remove numpy import

* revert removing copied from statements

* make style from copied

* remove copied from statement

* update copied from statement to include only np.ndarry

* add deterministic args, unittestskip equivalence tests
This commit is contained in:
Bartosz Szmelczynski
2023-04-27 19:52:28 +02:00
committed by GitHub
parent 27b66bea01
commit 88399476c3
3 changed files with 135 additions and 58 deletions

View File

@@ -158,7 +158,7 @@ class FlaxModelTesterMixin:
if "ForMultipleChoice" in model_class.__name__:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
if isinstance(v, (jnp.ndarray, np.ndarray))
if isinstance(v, (jnp.ndarray, np.ndarray)) and k != "indices_prng_key"
else v
for k, v in inputs_dict.items()
}
@@ -629,7 +629,6 @@ class FlaxModelTesterMixin:
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states