add flax whisper implementation (#20479)
* add flax whisper implementation * rever change to setup * remove unused imports * revert generation changes * flax whisper docs * docs * import order * import sorting * isort * add dummy objects * doc formatting * formatting * remove trailing whitespaces * fix flax whisper docs * add generation logic to unlock flax whisper * remove scans * give credits to Flax Bart implementation * remove unused imports * add license * remove assert * more credits to Bart * fix style * formatting * support left padding * add flax whisper generation test * remove copied from comments whenever not a full copy * fix docstrings for logits processors * revert change to FlaxForceTokensLogitsProcessor * revert doc changes * improve generation docs * reorganize * formatting * cleanup docs * add tests * handle empty list case * fix forced decoder ids in flax tests * add flax whisper to inits * upate dummy objects * docs for FlaxAutoModelForSpeechSeq2Seq * fix decoder_position_ids computation in pretrained model decode/__call__ fns * add Copied from statements as necessary * compute position_ids only in __call__ and decode methods of pretrained model subclasses * improve readabilityof compute positional embeddings * check dimensionality of input_features instead of hidden_states * copied from statement for init_cache * formatting * fix copies * fix copies * pass attention mask to encoder layers * fix decoder module outputs * set dtype Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * smaller flax model for whisper test * Update src/transformers/generation/flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/models/whisper/test_modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * cleanup Co-authored-by: Sylvain Gugger 
<35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * bias cleanup * doc fix * align style for force tokens processor * readability * fix input shape in tests * revert FlaxGenerationMixin docstring * formatting * fix tests * fix imports * consistent encoder hidden states * consistent hidden states * input shapes * typo * partial class trick * partial class for input shape * base_class with correct input shape * partial base classes * match by name * set main_input_name * compare on names * formatting * remove unused import * safer position ids computation * safer position id computation * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove identical inherited tests * fix prompt ids in tests * use generation config * use jnp array * better var names * more explicit bias use * import transformers * formatting * test formatting * remove unused imports * remove unused imports * formatting * isort * docs * fix ln orders for encoder hidden states * whisper unique generation stuff * flake * use finfo for attention bias * docs * Update src/transformers/generation/flax_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * docs * add timestamp flax test * jit for timestamps * formatting * clean up timestamps processor * formatting * remove if_true * cleanup --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@@ -402,7 +402,7 @@ Flax), PyTorch, and/or TensorFlow.
 | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
 | Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
 | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
-| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ |
+| Whisper | ✅ | ❌ | ✅ | ✅ | ✅ |
 | X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
 | X-MOD | ❌ | ❌ | ✅ | ❌ | ❌ |
 | XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -286,6 +286,10 @@ The following auto classes are available for the following audio tasks.
|
||||
|
||||
[[autodoc]] TFAutoModelForSpeechSeq2Seq
|
||||
|
||||
### FlaxAutoModelForSpeechSeq2Seq
|
||||
|
||||
[[autodoc]] FlaxAutoModelForSpeechSeq2Seq
|
||||
|
||||
### AutoModelForAudioXVector
|
||||
|
||||
[[autodoc]] AutoModelForAudioXVector
|
||||
|
||||
@@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper).
|
||||
|
||||
[[autodoc]] TFWhisperForConditionalGeneration
|
||||
- call
|
||||
|
||||
|
||||
## FlaxWhisperModel
|
||||
|
||||
[[autodoc]] FlaxWhisperModel
|
||||
- __call__
|
||||
|
||||
## FlaxWhisperForConditionalGeneration
|
||||
|
||||
[[autodoc]] FlaxWhisperForConditionalGeneration
|
||||
- __call__
|
||||
|
||||
@@ -3382,6 +3382,7 @@ else:
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
@@ -3395,6 +3396,7 @@ else:
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForSpeechSeq2Seq",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
"FlaxAutoModelForVision2Seq",
|
||||
]
|
||||
@@ -3578,6 +3580,13 @@ else:
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.whisper"].extend(
|
||||
[
|
||||
"FlaxWhisperForConditionalGeneration",
|
||||
"FlaxWhisperModel",
|
||||
"FlaxWhisperPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.xglm"].extend(
|
||||
[
|
||||
"FlaxXGLMForCausalLM",
|
||||
@@ -6381,6 +6390,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
@@ -6394,6 +6404,7 @@ if TYPE_CHECKING:
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSeq2SeqLM,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForSpeechSeq2Seq,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
FlaxAutoModelForVision2Seq,
|
||||
)
|
||||
@@ -6529,6 +6540,7 @@ if TYPE_CHECKING:
|
||||
FlaxWav2Vec2Model,
|
||||
FlaxWav2Vec2PreTrainedModel,
|
||||
)
|
||||
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
|
||||
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
||||
from .models.xlm_roberta import (
|
||||
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
||||
@@ -264,3 +264,191 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
||||
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using
    `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
    beginning of the generation.

    Args:
        begin_suppress_tokens (`List[int]`):
            Tokens to not sample.
        begin_index (`int`):
            Index where the tokens are suppressed.
    """

    def __init__(self, begin_suppress_tokens, begin_index):
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        self.begin_index = begin_index

    def __call__(self, input_ids, scores, cur_len: int):
        apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)

        scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)

        return scores
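Note on the class above: the mask `1 - jnp.bool_(cur_len - self.begin_index)` is non-zero only at the single step where `cur_len == begin_index`. The following is a minimal plain-Python sketch of that rule, assuming list-of-lists scores; the helper name `suppress_at_begin` is hypothetical and not part of the commit, which works on `jnp` arrays instead.

```python
NEG_INF = -float("inf")


def suppress_at_begin(scores, begin_suppress_tokens, begin_index, cur_len):
    """Hypothetical list-based mirror of the processor's __call__."""
    # 1 - bool(x) equals 1 only when x == 0, i.e. exactly at cur_len == begin_index.
    apply_penalty = 1 - bool(cur_len - begin_index)
    if not apply_penalty:
        # Any other step leaves the scores untouched (returned as a copy).
        return [row[:] for row in scores]
    suppressed = set(begin_suppress_tokens)
    return [
        [NEG_INF if tok in suppressed else s for tok, s in enumerate(row)]
        for row in scores
    ]


scores = [[0.1, 0.2, 0.3, 0.4], [1.0, 2.0, 3.0, 4.0]]
at_begin = suppress_at_begin(scores, [1, 3], begin_index=2, cur_len=2)
later = suppress_at_begin(scores, [1, 3], begin_index=2, cur_len=3)
```

In this sketch only `at_begin` has columns 1 and 3 masked; `later` is an unchanged copy of the input.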
|
||||
|
||||
|
||||
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
    to be `-inf` so they are not sampled.

    Args:
        suppress_tokens (`list`):
            Tokens to not sample.
    """

    def __init__(self, suppress_tokens: list):
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        scores = scores.at[..., self.suppress_tokens].set(-float("inf"))

        return scores
|
||||
|
||||
|
||||
class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
    token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
    to `-inf` so that they are sampled at their corresponding index.

    Args:
        force_token_map (`list`):
            Map giving token ids and indices where they will be forced to be sampled.
    """

    def __init__(self, force_token_map):
        force_token_map = dict(force_token_map)
        # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
        # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
        # Indexes without forced tokens will have a negative value.
        force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
        for index, token in force_token_map.items():
            force_token_array = force_token_array.at[index].set(token)
        self.force_token_array = jnp.int32(force_token_array)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def _force_token(generation_idx):
            batch_size = scores.shape[0]
            current_token = self.force_token_array[generation_idx]

            new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
            updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
            new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
            return new_scores

        scores = lax.cond(
            cur_len >= self.force_token_array.shape[0],
            # If the current length is >= the length of force_token_array, the processor does nothing.
            lambda: scores,
            # Otherwise, it may force a certain token.
            lambda: lax.cond(
                self.force_token_array[cur_len] >= 0,
                # Only valid (non-negative) tokens are forced
                lambda: _force_token(cur_len),
                # Otherwise, the processor does nothing.
                lambda: scores,
            ),
        )
        return scores
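Note on the class above: `__init__` turns the `{index: token}` map into a dense array with `-1` in every slot that has no forced token, and `__call__` then only needs an index lookup. The following is a rough plain-Python sketch of that lookup, assuming ordinary lists; `build_force_token_array` and `forced_token_at` are hypothetical names, and the token ids are placeholder values.

```python
def build_force_token_array(force_token_map):
    """Hypothetical list-based mirror of the array built in __init__."""
    mapping = dict(force_token_map)
    # Slots without a forced token hold -1, the same sentinel the diff uses.
    array = [-1] * (max(mapping) + 1)
    for index, token in mapping.items():
        array[index] = token
    return array


def forced_token_at(array, cur_len):
    """Return the token forced at cur_len, or None when nothing is forced."""
    if cur_len >= len(array):
        # Past the end of the array: the processor leaves scores alone.
        return None
    token = array[cur_len]
    return token if token >= 0 else None


array = build_force_token_array([[1, 50259], [2, 50359], [4, 50363]])
```

The dense layout matters in the JAX version because a Python dict lookup on a traced `cur_len` would not work inside `jit`, whereas array indexing does.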
|
||||
|
||||
|
||||
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
    r"""
    Whisper specific Processor. This processor can be used to constrain the generation of timestamp tokens. It sets
    the log probs of disallowed tokens to `-inf` so that timestamps are only sampled at valid positions.

    Args:
        generate_config (`GenerateConfig`):
            The generate config used to generate the output. The following parameters are required:
                eos_token_id (`int`, *optional*, defaults to 50257):
                    The id of the *end-of-sequence* token.
                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                    The id of the `"<|notimestamps|>"` token.
                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                    predicting timestamps that are too far in the future.
    """

    def __init__(self, generate_config, model_config, decoder_input_length):
        self.eos_token_id = generate_config.eos_token_id
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1

        self.begin_index = decoder_input_length + 1

        if generate_config.is_multilingual:
            # room for language token and task token
            self.begin_index += 2
        if hasattr(generate_config, "max_initial_timestamp_index"):
            self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
        else:
            self.max_initial_timestamp_index = model_config.vocab_size
        if self.max_initial_timestamp_index is None:
            self.max_initial_timestamp_index = model_config.vocab_size

    def __call__(self, input_ids, scores, cur_len):
        # suppress <|notimestamps|> which is handled by without_timestamps
        scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))

        def handle_pairs(input_ids_k, scores_k):
            last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
            last_was_timestamp = jnp.where(
                input_ids_k[cur_len - 1] >= self.timestamp_begin,
                True and last_was_timestamp,
                False,
            )

            penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
            penultimate_was_timestamp = jnp.where(
                input_ids_k[cur_len - 2] >= self.timestamp_begin,
                True,
                penultimate_was_timestamp,
            )

            return jnp.where(
                last_was_timestamp,
                jnp.where(
                    penultimate_was_timestamp > 0,
                    scores_k.at[self.timestamp_begin :].set(-float("inf")),
                    scores_k.at[: self.eos_token_id].set(-float("inf")),
                ),
                scores_k,
            )

        scores = jax.vmap(handle_pairs)(input_ids, scores)

        apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
        apply_max_initial_timestamp = jnp.where(
            self.max_initial_timestamp_index is not None,
            True and apply_max_initial_timestamp,
            False,
        )

        last_allowed = self.timestamp_begin + self.max_initial_timestamp_index

        scores = jnp.where(
            apply_max_initial_timestamp,
            scores.at[:, last_allowed + 1 :].set(-float("inf")),
            scores,
        )

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = jax.nn.log_softmax(scores, axis=-1)

        def handle_cumulative_probs(logprobs_k, scores_k):
            timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
            max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
            return jnp.where(
                timestamp_logprob > max_text_token_logprob,
                scores_k.at[: self.timestamp_begin].set(-float("inf")),
                scores_k,
            )

        scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

        return scores
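Note on `handle_cumulative_probs` above: the rule compares the total probability mass of all timestamp tokens (a log-sum-exp over their log probs) with the single most likely text token, and masks every text token when the timestamps win. The following is a small stdlib-only sketch for one row of scores; the function names and the five-token toy vocabulary are assumptions made for illustration.

```python
import math


def log_softmax(scores):
    # Numerically stable log-softmax over a plain list.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def prefer_timestamps(scores, timestamp_begin):
    """Hypothetical single-row mirror of handle_cumulative_probs."""
    logprobs = log_softmax(scores)
    ts = logprobs[timestamp_begin:]
    m = max(ts)
    # Log of the summed probability of all timestamp tokens.
    timestamp_logprob = m + math.log(sum(math.exp(x - m) for x in ts))
    max_text_token_logprob = max(logprobs[:timestamp_begin])
    if timestamp_logprob > max_text_token_logprob:
        # Timestamps win: mask every text token.
        return [-float("inf")] * timestamp_begin + list(scores[timestamp_begin:])
    return list(scores)


# Tokens 0-1 play the role of text, tokens 2-4 the role of timestamps.
masked = prefer_timestamps([1.0, 0.0, 0.5, 0.5, 0.5], timestamp_begin=2)
kept = prefer_timestamps([5.0, 0.0, 0.5, 0.5, 0.5], timestamp_begin=2)
```

In the first call no single timestamp beats the best text token, yet their combined mass does, so the text tokens are masked; in the second call the text token dominates and the scores are returned unchanged.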
|
||||
|
||||
@@ -37,8 +37,11 @@ from .configuration_utils import GenerationConfig
|
||||
from .flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
@@ -164,6 +167,50 @@ class FlaxGenerationMixin:
|
||||
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
|
||||
return model_kwargs
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
batch_size: int,
|
||||
decoder_start_token_id: int = None,
|
||||
bos_token_id: int = None,
|
||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
) -> jnp.ndarray:
|
||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||
# Only use this arg if not None, otherwise just remove from model_kwargs
|
||||
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||
if decoder_input_ids is not None:
|
||||
return decoder_input_ids
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
|
||||
|
||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||
# retrieve decoder_start_token_id for encoder-decoder models
|
||||
# fall back to bos_token_id if necessary
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id
|
||||
if decoder_start_token_id is not None
|
||||
else self.generation_config.decoder_start_token_id
|
||||
)
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
||||
if decoder_start_token_id is not None:
|
||||
return decoder_start_token_id
|
||||
elif (
|
||||
hasattr(self.config, "decoder")
|
||||
and hasattr(self.config.decoder, "decoder_start_token_id")
|
||||
and self.config.decoder.decoder_start_token_id is not None
|
||||
):
|
||||
return self.config.decoder.decoder_start_token_id
|
||||
elif bos_token_id is not None:
|
||||
return bos_token_id
|
||||
elif (
|
||||
hasattr(self.config, "decoder")
|
||||
and hasattr(self.config.decoder, "bos_token_id")
|
||||
and self.config.decoder.bos_token_id is not None
|
||||
):
|
||||
return self.config.decoder.bos_token_id
|
||||
raise ValueError(
|
||||
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
||||
)
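Note on `_get_decoder_start_token_id` above: the lookup order is explicit argument or generation config, then the nested decoder config, then `bos_token_id`, then the nested decoder `bos_token_id`, and finally an error. The following is a minimal sketch of that order as a free function, assuming `SimpleNamespace` objects as stand-ins for the real config classes; the name `get_decoder_start_token_id` and its argument layout are hypothetical.

```python
from types import SimpleNamespace


def get_decoder_start_token_id(generation_config, config, decoder_start_token_id=None, bos_token_id=None):
    """Hypothetical free-function mirror of the lookup order in the diff."""
    if decoder_start_token_id is None:
        decoder_start_token_id = generation_config.decoder_start_token_id
    if bos_token_id is None:
        bos_token_id = generation_config.bos_token_id

    # getattr with a default stands in for the hasattr(...) checks in the diff.
    decoder = getattr(config, "decoder", None)
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if getattr(decoder, "decoder_start_token_id", None) is not None:
        return decoder.decoder_start_token_id
    if bos_token_id is not None:
        return bos_token_id
    if getattr(decoder, "bos_token_id", None) is not None:
        return decoder.bos_token_id
    raise ValueError("`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation.")


gen = SimpleNamespace(decoder_start_token_id=None, bos_token_id=7)
nested = SimpleNamespace(decoder=SimpleNamespace(decoder_start_token_id=3, bos_token_id=9))
flat = SimpleNamespace()

explicit = get_decoder_start_token_id(gen, flat, decoder_start_token_id=85)
from_nested = get_decoder_start_token_id(gen, nested)
from_bos = get_decoder_start_token_id(gen, flat)
```

Note that a nested decoder `decoder_start_token_id` takes precedence over a top-level `bos_token_id`, which is why `from_nested` resolves to the nested value in this sketch.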
|
||||
|
||||
@staticmethod
|
||||
def _expand_to_num_beams(tensor, num_beams):
|
||||
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
|
||||
@@ -224,6 +271,7 @@ class FlaxGenerationMixin:
|
||||
prng_key: Optional[jnp.ndarray] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -244,6 +292,10 @@ class FlaxGenerationMixin:
|
||||
considerably slower runtime.
|
||||
params (`Dict[str, jnp.ndarray]`, *optional*):
|
||||
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
||||
logits_processor (`FlaxLogitsProcessorList`, *optional*):
|
||||
Custom logits processors that complement the default logits processors built from arguments and
|
||||
generation config. If a logit processor is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
kwargs:
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@@ -277,6 +329,8 @@ class FlaxGenerationMixin:
|
||||
generation_config.validate()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
|
||||
|
||||
# set init values
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
@@ -306,12 +360,19 @@ class FlaxGenerationMixin:
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
                 )

+        batch_size = input_ids.shape[0]
+
         if self.config.is_encoder_decoder:
             # add encoder_outputs to model_kwargs
             if model_kwargs.get("encoder_outputs") is None:
                 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
             # prepare decoder_input_ids for generation
-            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
+                model_kwargs=model_kwargs,
+            )

         # Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -347,7 +408,11 @@ class FlaxGenerationMixin:
                 " increasing`max_new_tokens`."
             )

-        logits_processor = self._get_logits_processor(generation_config=generation_config)
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            logits_processor=logits_processor,
+        )

         if not generation_config.do_sample and generation_config.num_beams == 1:
             return self._greedy_search(
@@ -419,7 +484,12 @@ class FlaxGenerationMixin:

         return warpers

-    def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
+    def _get_logits_processor(
+        self,
+        generation_config: GenerationConfig,
+        input_ids_seq_length: int,
+        logits_processor: Optional[FlaxLogitsProcessorList],
+    ) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
         instances used to modify the scores of the language model head.
@@ -440,9 +510,51 @@ class FlaxGenerationMixin:
|
||||
processors.append(
|
||||
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
|
||||
)
|
||||
if generation_config.suppress_tokens is not None:
|
||||
processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
|
||||
if generation_config.begin_suppress_tokens is not None:
|
||||
begin_index = input_ids_seq_length
|
||||
begin_index = (
|
||||
begin_index
|
||||
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
|
||||
else begin_index + 1
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
|
||||
# generation starts after the last token that is forced
|
||||
begin_index += generation_config.forced_decoder_ids[-1][0]
|
||||
processors.append(
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
forced_decoder_ids = [
|
||||
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
|
||||
]
|
||||
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
|
||||
return processors
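Note on the index arithmetic above: `forced_decoder_ids` are given relative to the start of generation, so they are shifted by the prompt length, and `begin_index` for the begin-suppress processor is pushed past the last forced position. The following is a plain-Python sketch of both computations; the helper names are hypothetical and the token ids are placeholder values.

```python
def shifted_forced_ids(forced_decoder_ids, input_ids_seq_length):
    # Mirrors: [input_ids_seq_length + i[0] - 1, i[1]] for each pair in the diff.
    return [[input_ids_seq_length + idx - 1, tok] for idx, tok in forced_decoder_ids]


def begin_suppress_index(input_ids_seq_length, forced_bos_token_id, forced_decoder_ids):
    begin_index = input_ids_seq_length
    # A forced BOS token on a length-1 prompt occupies one extra position.
    if not (input_ids_seq_length > 1 or forced_bos_token_id is None):
        begin_index += 1
    if forced_decoder_ids is not None and len(forced_decoder_ids) > 0:
        # Generation starts after the last token that is forced.
        begin_index += forced_decoder_ids[-1][0]
    return begin_index


forced = [[1, 50259], [2, 50359], [3, 50363]]
shifted_short = shifted_forced_ids(forced, input_ids_seq_length=1)
shifted_long = shifted_forced_ids(forced, input_ids_seq_length=4)
begin = begin_suppress_index(1, None, forced)
```

With a single start token the forced positions are unchanged, while a four-token prompt moves them to positions 4 through 6.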
|
||||
|
||||
def _merge_criteria_processor_list(
|
||||
self,
|
||||
default_list: FlaxLogitsProcessorList,
|
||||
custom_list: FlaxLogitsProcessorList,
|
||||
) -> FlaxLogitsProcessorList:
|
||||
if len(custom_list) == 0:
|
||||
return default_list
|
||||
for default in default_list:
|
||||
for custom in custom_list:
|
||||
if type(custom) is type(default):
|
||||
object_type = "logits processor"
|
||||
raise ValueError(
|
||||
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
||||
f" `generate`, but it has already been created with the values {default}. {default} has been"
|
||||
" created by passing the corresponding arguments to generate or by the model's config default"
|
||||
f" values. If you just want to change the default values of {object_type} consider passing"
|
||||
f" them as arguments to `generate` instead of using a custom {object_type}."
|
||||
)
|
||||
default_list.extend(custom_list)
|
||||
return default_list
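Note on `_merge_criteria_processor_list` above: a custom processor is rejected when a default processor of exactly the same type already exists; otherwise the custom list is appended to the default list in place. The following is a minimal sketch of that check, assuming two empty placeholder classes; `merge_processor_lists`, `ProcessorA` and `ProcessorB` are hypothetical names, and the error message is shortened.

```python
class ProcessorA:
    """Placeholder for a default logits processor."""


class ProcessorB:
    """Placeholder for a custom logits processor."""


def merge_processor_lists(default_list, custom_list):
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            # Exact type match, as in the diff; subclasses are not treated as duplicates.
            if type(custom) is type(default):
                raise ValueError(
                    f"A custom logits processor of type {type(custom).__name__} duplicates a default one."
                )
    # The default list is extended in place and returned.
    default_list.extend(custom_list)
    return default_list


merged = merge_processor_lists([ProcessorA()], [ProcessorB()])
merged_names = [type(p).__name__ for p in merged]
```

Because the comparison uses `type(...) is type(...)`, passing a subclass of a default processor would not trigger the error in this sketch.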
|
||||
|
||||
def _greedy_search(
|
||||
self,
|
||||
input_ids: None,
|
||||
|
||||
@@ -163,6 +163,7 @@ else:
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
@@ -176,6 +177,7 @@ else:
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForSpeechSeq2Seq",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
"FlaxAutoModelForVision2Seq",
|
||||
]
|
||||
@@ -320,6 +322,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
@@ -333,6 +336,7 @@ if TYPE_CHECKING:
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSeq2SeqLM,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForSpeechSeq2Seq,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
FlaxAutoModelForVision2Seq,
|
||||
)
|
||||
|
||||
@@ -55,6 +55,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
|
||||
("vit", "FlaxViTModel"),
|
||||
("wav2vec2", "FlaxWav2Vec2Model"),
|
||||
("whisper", "FlaxWhisperModel"),
|
||||
("xglm", "FlaxXGLMModel"),
|
||||
("xlm-roberta", "FlaxXLMRobertaModel"),
|
||||
]
|
||||
@@ -76,6 +77,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("roformer", "FlaxRoFormerForMaskedLM"),
|
||||
("t5", "FlaxT5ForConditionalGeneration"),
|
||||
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
|
||||
("whisper", "FlaxWhisperForConditionalGeneration"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
||||
]
|
||||
)
|
||||
@@ -219,6 +221,7 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
|
||||
("whisper", "FlaxWhisperForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+)


 _import_structure = {
@@ -50,6 +56,19 @@ else:
|
||||
"TFWhisperPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_whisper"] = [
|
||||
"FlaxWhisperForConditionalGeneration",
|
||||
"FlaxWhisperModel",
|
||||
"FlaxWhisperPreTrainedModel",
|
||||
]
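Note on the guard above: the `try` / `except` / `else` shape registers the Flax class names only when the backend can be imported, and silently skips them otherwise. The following is a stdlib-only sketch of the same control flow; `is_available` and `register_optional` are hypothetical helpers, the exception class is a local stand-in, and the real `is_flax_available()` may check more than module presence.

```python
import importlib.util


class OptionalDependencyNotAvailable(Exception):
    """Local stand-in for the exception of the same name in the diff."""


def is_available(module_name):
    # Assumption: presence of an importable top-level module is a sufficient check here.
    return importlib.util.find_spec(module_name) is not None


def register_optional(import_structure, module_name, key, names):
    try:
        if not is_available(module_name):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Backend missing: leave the import structure untouched.
        pass
    else:
        import_structure[key] = list(names)
    return import_structure


present = register_optional({}, "json", "modeling_flax_whisper", ["FlaxWhisperModel"])
absent = register_optional({}, "surely_not_a_real_module_xyz", "modeling_flax_whisper", ["FlaxWhisperModel"])
```

Here `json` merely plays the role of an installed backend so that the sketch runs without Flax.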
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
|
||||
from .feature_extraction_whisper import WhisperFeatureExtractor
|
||||
@@ -82,6 +101,18 @@ if TYPE_CHECKING:
|
||||
TFWhisperPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_whisper import (
|
||||
FlaxWhisperForConditionalGeneration,
|
||||
FlaxWhisperModel,
|
||||
FlaxWhisperPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
1470
src/transformers/models/whisper/modeling_flax_whisper.py
Normal file
File diff suppressed because it is too large
@@ -162,6 +162,9 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
@@ -241,6 +244,13 @@ class FlaxAutoModelForSequenceClassification(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForSpeechSeq2Seq(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -1130,6 +1140,27 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxWhisperModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxWhisperPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXGLMForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
706
tests/models/whisper/test_modeling_flax_whisper.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
import inspect
import tempfile
import unittest

import transformers
from transformers import WhisperConfig, is_flax_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor


if is_datasets_available():
    import datasets
    from datasets import load_dataset

if is_flax_available():
    import numpy as np

    import jax
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict
    from transformers import (
        FLAX_MODEL_MAPPING,
        FlaxWhisperForConditionalGeneration,
        FlaxWhisperModel,
        WhisperFeatureExtractor,
        WhisperProcessor,
    )
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model


@require_flax
|
||||
class FlaxWhisperModelTester:
|
||||
config_cls = WhisperConfig
|
||||
config_updates = {}
|
||||
hidden_act = "gelu"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=60,
|
||||
is_training=True,
|
||||
use_labels=False,
|
||||
vocab_size=99,
|
||||
d_model=16,
|
||||
decoder_attention_heads=4,
|
||||
decoder_ffn_dim=16,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=4,
|
||||
encoder_ffn_dim=16,
|
||||
encoder_layers=2,
|
||||
input_channels=1,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=70,
|
||||
max_source_positions=30,
|
||||
max_target_positions=40,
|
||||
bos_token_id=98,
|
||||
eos_token_id=98,
|
||||
pad_token_id=0,
|
||||
num_mel_bins=80,
|
||||
decoder_start_token_id=85,
|
||||
num_conv_layers=1,
|
||||
suppress_tokens=None,
|
||||
begin_suppress_tokens=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.hidden_size = d_model
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.num_attention_heads = encoder_attention_heads
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_seq_length = seq_length // 2
|
||||
self.decoder_seq_length = 1
|
||||
self.input_channels = input_channels
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.max_source_positions = max_source_positions
|
||||
self.max_target_positions = max_target_positions
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.suppress_tokens = suppress_tokens
|
||||
self.begin_suppress_tokens = begin_suppress_tokens
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)
|
||||
|
||||
decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]])
|
||||
|
||||
config = WhisperConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
num_mel_bins=self.num_mel_bins,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
is_encoder_decoder=True,
|
||||
activation_function=self.hidden_act,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_source_positions=self.max_source_positions,
|
||||
max_target_positions=self.max_target_positions,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
tie_word_embeddings=True,
|
||||
d_model=self.d_model,
|
||||
decoder_attention_heads=self.decoder_attention_heads,
|
||||
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||
decoder_layers=self.decoder_layers,
|
||||
encoder_attention_heads=self.encoder_attention_heads,
|
||||
encoder_ffn_dim=self.encoder_ffn_dim,
|
||||
encoder_layers=self.encoder_layers,
|
||||
suppress_tokens=self.suppress_tokens,
|
||||
begin_suppress_tokens=self.begin_suppress_tokens,
|
||||
)
|
||||
inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids)
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_whisper_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    # `input_ids` carries the input features; it is returned under the "input_features" key.
    if decoder_attention_mask is None:
        # Always attend to the first decoder position; mask the remaining positions where they hold the pad token.
        decoder_attention_mask = np.concatenate(
            [
                np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8),
                np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8),
            ],
            axis=-1,
        )
    return {
        "input_features": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }
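
# Illustrative sketch only, not part of the original test file: it shows how the
# default decoder attention mask in prepare_whisper_inputs_dict is derived from
# the pad token. The toy batch and the ids 85 (start token) and 0 (pad token)
# are assumptions that mirror the tester defaults.

```python
import numpy as np

# Hypothetical toy batch: 85 is the start token, 0 is the pad token.
pad_token_id = 0
decoder_input_ids = np.array([[85, 7, 0, 0], [85, 3, 9, 0]])

# The first position is always attended to, even if it were equal to the pad id;
# the remaining positions are masked wherever they hold the pad token.
decoder_attention_mask = np.concatenate(
    [
        np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8),
        np.not_equal(decoder_input_ids[:, 1:], pad_token_id).astype(np.int8),
    ],
    axis=-1,
)
print(decoder_attention_mask.tolist())  # [[1, 1, 0, 0], [1, 1, 1, 0]]
```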
|
||||
|
||||
def partialclass(cls, *args, **kwargs):
    # Subclass `cls` with some `__init__` arguments pre-bound.
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


def make_partial_class(full_class, *args, **kwargs):
    partial_class = partialclass(full_class, *args, **kwargs)
    # Keep the original name and module so that name-based lookups still resolve.
    partial_class.__name__ = full_class.__name__
    partial_class.__module__ = full_class.__module__

    return partial_class
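
# Illustrative sketch only, not part of the original test file: it shows the
# functools.partialmethod pattern that the two helpers above rely on, applied
# to a toy class. ToyModel, BoundToyModel and the shapes are hypothetical
# stand-ins for a Flax model class and its input_shape argument.

```python
import functools


class ToyModel:  # hypothetical stand-in for a Flax model class
    def __init__(self, config, input_shape=(1, 80, 3000)):
        self.config = config
        self.input_shape = input_shape


class BoundToyModel(ToyModel):
    # Pre-bind input_shape; callers can still override it explicitly.
    __init__ = functools.partialmethod(ToyModel.__init__, input_shape=(1, 80, 60))


# Mirror make_partial_class: keep the original class name and module.
BoundToyModel.__name__ = ToyModel.__name__
BoundToyModel.__module__ = ToyModel.__module__

model = BoundToyModel({"d_model": 16})
overridden = BoundToyModel({"d_model": 16}, input_shape=(1, 80, 30))
```

# Because the subclass keeps the original __name__, code that strips the "Flax"
# prefix from model_class.__name__ to find the PyTorch class keeps working.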
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else ()
|
||||
all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxWhisperModelTester(self)
|
||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
self.all_model_classes = tuple(
|
||||
make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
|
||||
)
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite because of `input_features`
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.__call__)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_features", "decoder_input_ids"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
# overwrite because of `input_features`
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_features, decoder_input_ids, **kwargs):
|
||||
return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = abs(base_params[key] - base_param_from_head[key]).max().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
def test_save_load_from_base(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = abs(base_params[key] - base_param_from_head[key]).max().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite because of `input_features`
|
||||
def test_save_load_to_base(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == base_class.__name__:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxWhisperModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return WhisperProcessor.from_pretrained("openai/whisper-base")
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_tiny_logits_librispeech(self):
|
||||
model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="np").input_features
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=np.array([[50258, 50259, 50359]]),
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = np.array(
|
||||
[
|
||||
2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
|
||||
0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
|
||||
4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
|
||||
0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
def test_small_en_logits_librispeech(self):
|
||||
model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="np").input_features
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=np.array([model.config.decoder_start_token_id]),
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = np.array(
|
||||
[
|
||||
-3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
|
||||
-8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
|
||||
-6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
|
||||
-10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
|
||||
-11.1146, -8.1918
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
def test_large_logits_librispeech(self):
|
||||
model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
processed_inputs = processor(
|
||||
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
|
||||
)
|
||||
input_features = processed_inputs.input_features
|
||||
decoder_input_ids = processed_inputs.labels
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = np.array(
|
||||
[
|
||||
2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
|
||||
1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
|
||||
1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
|
||||
1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
def test_tiny_en_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
||||
model.config.decoder_start_token_id = 50257
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(
|
||||
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
|
||||
).input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
|
||||
" classes and we are glad to"
|
||||
)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
def test_tiny_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(
|
||||
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
|
||||
).input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
|
||||
" classes and we are glad"
|
||||
)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
def test_large_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(
|
||||
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
|
||||
).input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
|
||||
transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
def test_large_generation_multilingual(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
|
||||
|
||||
ds = load_dataset("common_voice", "ja", split="test", streaming=True)
|
||||
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
input_speech = next(iter(ds))["audio"]["array"]
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np")
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
).sequences
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Kimura-san called me."
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
def test_large_batched_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
|
||||
generated_ids = model.generate(input_features, max_length=20).sequences
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = np.array(
|
||||
[
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
def test_tiny_en_batched_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
|
||||
generated_ids = model.generate(input_features, max_length=20).sequences
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = np.array(
|
||||
[
|
||||
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
|
||||
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
|
||||
[50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
|
||||
[50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
|
||||
]
|
||||
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef looming",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_timestamp_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
input_speech = np.concatenate(self._load_datasamples(4))
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features
|
||||
|
||||
generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
|
||||
|
||||
generated_ids = generate_fn(input_features)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
"text": (
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
|
||||
" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
|
||||
" of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
|
||||
" its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
|
||||
" work is really Greek after all, and"
|
||||
),
|
||||
"offsets": [
|
||||
{
|
||||
"text": (
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
|
||||
),
|
||||
"timestamp": (0.0, 6.5600000000000005),
|
||||
},
|
||||
{
|
||||
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
"timestamp": (6.5600000000000005, 11.24),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef"
|
||||
" looming"
|
||||
),
|
||||
"timestamp": (11.24, 16.88),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" before us, similarly drawn from eating and its results occur most readily to the mind."
|
||||
),
|
||||
"timestamp": (16.88, 23.76),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
|
||||
),
|
||||
"timestamp": (23.76, 29.44),
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
@@ -22,9 +22,10 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import WhisperConfig
|
||||
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device
|
||||
from transformers.utils import cached_property, is_flax_available, is_torch_available
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -48,6 +49,13 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
def prepare_whisper_inputs_dict(
|
||||
config,
|
||||
@@ -747,6 +755,159 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
# no flax model exists for this class
|
||||
return
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only Flax inputs that actually exist in the function args are forwarded
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
# no flax model exists for this class
|
||||
return
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
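Both cross-framework tests in the hunk above rely on the same two small steps: dropping PyTorch inputs that the Flax `__call__` does not accept, and comparing only the output keys whose values are not `None`. A minimal sketch of those two steps in isolation follows, with no PyTorch or Flax dependency. `fake_flax_call`, `keep_supported_kwargs` and `non_none_keys` are hypothetical names introduced here for illustration; they are not part of the commit, and the stand-in signature is an assumption, not the real `FlaxWhisperModel.__call__` signature.

```python
import inspect


def fake_flax_call(input_features, decoder_input_ids, attention_mask=None):
    """Hypothetical stand-in for a Flax model's __call__ (assumed signature)."""
    # Mimic a model output where an optional field is absent (None).
    return {"logits": [0.0], "past_key_values": None}


def keep_supported_kwargs(fn, inputs):
    # Keep only the entries whose key is a parameter name of `fn`,
    # as the tests do with inspect.signature(fx_model.__call__).
    accepted = inspect.signature(fn).parameters.keys()
    return {k: v for k, v in inputs.items() if k in accepted}


def non_none_keys(outputs):
    # Keys of all outputs that are actually populated, in order.
    return tuple(k for k, v in outputs.items() if v is not None)


# `head_mask` is not a parameter of the stand-in, so it is filtered out.
pt_inputs = {"input_features": [[0.1]], "decoder_input_ids": [[1]], "head_mask": None}
fx_inputs = keep_supported_kwargs(fake_flax_call, pt_inputs)
fx_keys = non_none_keys(fake_flax_call(**fx_inputs))
```

Filtering by signature rather than by a hard-coded list keeps such a test valid if the Flax signature later gains or loses an argument.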
@@ -68,6 +68,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "DeformableDetrEncoder",  # Building part of bigger (tested) model.
    "DeformableDetrDecoder",  # Building part of bigger (tested) model.
    "OPTDecoder",  # Building part of bigger (tested) model.
    "FlaxWhisperDecoder",  # Building part of bigger (tested) model.
    "FlaxWhisperEncoder",  # Building part of bigger (tested) model.
    "WhisperDecoder",  # Building part of bigger (tested) model.
    "WhisperEncoder",  # Building part of bigger (tested) model.
    "DecisionTransformerGPT2Model",  # Building part of bigger (tested) model.