Fix quality
This commit is contained in:
@@ -34,11 +34,11 @@ if is_datasets_available():
|
||||
from datasets import load_dataset
|
||||
|
||||
if is_flax_available():
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import (
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxWhisperForConditionalGeneration,
|
||||
|
||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
|
||||
Reference in New Issue
Block a user