Fix quality
This commit is contained in:
@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
_do_init: bool = True,
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
|||||||
task=None,
|
task=None,
|
||||||
language=None,
|
language=None,
|
||||||
is_multilingual=None,
|
is_multilingual=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
if generation_config is None:
|
if generation_config is None:
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
|||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jnp.DeviceArray] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = decoder_input_ids.shape
|
batch_size, seq_length = decoder_input_ids.shape
|
||||||
|
|||||||
@@ -34,11 +34,11 @@ if is_datasets_available():
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
import numpy as np
|
||||||
from flax.core.frozen_dict import unfreeze
|
from flax.core.frozen_dict import unfreeze
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
FLAX_MODEL_MAPPING,
|
FLAX_MODEL_MAPPING,
|
||||||
FlaxWhisperForConditionalGeneration,
|
FlaxWhisperForConditionalGeneration,
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user