Fix quality
This commit is contained in:
@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
||||
task=None,
|
||||
language=None,
|
||||
is_multilingual=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if generation_config is None:
|
||||
generation_config = self.generation_config
|
||||
@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = decoder_input_ids.shape
|
||||
|
||||
@@ -34,11 +34,11 @@ if is_datasets_available():
|
||||
from datasets import load_dataset
|
||||
|
||||
if is_flax_available():
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import (
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxWhisperForConditionalGeneration,
|
||||
|
||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
|
||||
Reference in New Issue
Block a user