Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)
* set encoder's PE as non-trainable
* freeze flax
* init sinusoids
* add test for non-trainable embed positions
* simplify TF encoder embed_pos
* revert tf
* clean up
* add sinusoidal init for jax
* make consistent sinusoidal function
* fix dtype
* add default dtype
* use numpy for sinusoids. fix jax
* add sinusoid init for TF
* fix
* use custom embedding
* use specialized init for each impl
* fix sinusoids init. add test for pytorch
* fix TF dtype
* simplify sinusoid init for flax and tf
* add tests for TF
* change default dtype to float32
* add sinusoid test for flax
* Update src/transformers/models/whisper/modeling_flax_whisper.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/whisper/modeling_tf_whisper.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* move sinusoidal init to _init_weights

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -46,6 +46,7 @@ if is_flax_available():
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
|
||||
|
||||
|
||||
@require_flax
|
||||
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
    """Check that the encoder position embeddings match the sinusoidal initializer.

    For each model class, a model is built from the common test config, the
    encoder's ``embed_positions`` embedding is read out of the parameter tree,
    and it is compared (via ``allclose``) with the output of
    ``sinusoidal_embedding_init`` for the same shape.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # When the parameter tree has an entry under the base-model prefix,
        # the encoder parameters are looked up beneath that entry.
        param_tree = model.params
        prefix = model.base_model_prefix
        if prefix in param_tree:
            param_tree = model.params[prefix]

        position_table = param_tree["encoder"]["embed_positions"]["embedding"]
        # None is passed as the initializer's first positional argument.
        expected_table = sinusoidal_embedding_init(None, position_table.shape)

        tables_match = jax.numpy.allclose(position_table, expected_table)
        self.assertTrue(tables_match)
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
|
||||
@@ -42,7 +42,11 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
|
||||
from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
|
||||
from transformers.models.whisper.modeling_tf_whisper import (
|
||||
TFWhisperDecoder,
|
||||
TFWhisperEncoder,
|
||||
sinusoidal_embedding_init,
|
||||
)
|
||||
|
||||
|
||||
def prepare_whisper_inputs_dict(
|
||||
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_requires_grad_encoder_embed_positions(self):
    """Assert that the encoder's ``embed_positions`` layer is not trainable.

    The check is repeated for every class in ``self.all_model_classes``, each
    instantiated from the tester's config.
    """
    config = self.model_tester.get_config()

    for model_class in self.all_model_classes:
        position_layer = model_class(config).get_encoder().embed_positions
        self.assertFalse(position_layer.trainable)
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
    """Check that the built encoder's position embeddings match the sinusoidal initializer.

    Each model class is instantiated from the tester's config and built; the
    first weight array of the encoder's ``embed_positions`` layer is then
    compared (via ``np.allclose``) with ``sinusoidal_embedding_init`` evaluated
    for the same shape.
    """
    config = self.model_tester.get_config()

    for model_class in self.all_model_classes:
        model = model_class(config)
        # build() is called before the layer weights are read.
        model.build()

        position_layer = model.get_encoder().embed_positions
        actual_table = position_layer.get_weights()[0]
        expected_table = sinusoidal_embedding_init(actual_table.shape).numpy()

        tables_match = np.allclose(actual_table, expected_table)
        self.assertTrue(tables_match)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
@@ -49,7 +49,7 @@ if is_torch_available():
|
||||
WhisperProcessor,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertFalse(all(encoder_grads))
|
||||
self.assertTrue(all(decoder_grads))
|
||||
|
||||
def test_requires_grad_encoder_embed_positions(self):
    """Assert that the encoder's ``embed_positions`` weight does not require gradients.

    The check is repeated for every class in ``self.all_model_classes``, each
    instantiated from the tester's config.
    """
    config = self.model_tester.get_config()

    for model_class in self.all_model_classes:
        position_weight = model_class(config).get_encoder().embed_positions.weight
        self.assertFalse(position_weight.requires_grad)
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
    """Check that the encoder's position-embedding weight equals ``sinusoids`` for its shape.

    Each model class is instantiated from the tester's config; the encoder's
    ``embed_positions.weight`` is compared (via ``torch.allclose``) with
    ``sinusoids`` called with the weight's dimensions as positional arguments.
    """
    config = self.model_tester.get_config()

    for model_class in self.all_model_classes:
        encoder = model_class(config).get_encoder()
        position_weight = encoder.embed_positions.weight

        # The weight's shape is unpacked into the positional arguments of sinusoids().
        expected_weight = sinusoids(*position_weight.shape)
        self.assertTrue(torch.allclose(position_weight, expected_weight))
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user