From c87bbe1ff0886044a3b2add3530becff4b2dcc9b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Mon, 20 Feb 2023 03:27:09 -0500 Subject: [PATCH] Fix quality --- src/transformers/models/whisper/modeling_flax_whisper.py | 6 +++--- tests/models/whisper/test_modeling_flax_whisper.py | 4 ++-- tests/models/whisper/test_modeling_whisper.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index f66a02453d..a928e145e8 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, - **kwargs + **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) @@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): task=None, language=None, is_multilingual=None, - **kwargs + **kwargs, ): if generation_config is None: generation_config = self.generation_config @@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): attention_mask: Optional[jnp.DeviceArray] = None, decoder_attention_mask: Optional[jnp.DeviceArray] = None, encoder_outputs=None, - **kwargs + **kwargs, ): # initializing the cache batch_size, seq_length = decoder_input_ids.shape diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index a102f5d48d..3f1e201d72 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -34,11 +34,11 @@ if is_datasets_available(): from datasets import load_dataset if is_flax_available(): - import numpy as np - import jax + import numpy as np from flax.core.frozen_dict import unfreeze from flax.traverse_util import flatten_dict + from transformers import ( FLAX_MODEL_MAPPING, FlaxWhisperForConditionalGeneration, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 68f8a5317a..fd5b5da014 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -51,6 +51,7 @@ if is_torch_available(): if is_flax_available(): import jax.numpy as jnp + from transformers.modeling_flax_pytorch_utils import ( convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model,