Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)

* set encoder's PE as non-trainable

* freeze flax

* init sinusoids

* add test for non-trainable embed positions

* simplify TF encoder embed_pos

* revert tf

* clean up

* add sinusoidal init for jax

* make consistent sinusoidal function

* fix dtype

* add default dtype

* use numpy for sinusoids. fix jax

* add sinusoid init for TF

* fix

* use custom embedding

* use specialized init for each impl

* fix sinusoids init. add test for pytorch

* fix TF dtype

* simplify sinusoid init for flax and tf

* add tests for TF

* change default dtype to float32

* add sinusoid test for flax

* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* move sinusoidal init to _init_weights

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Thien Tran
2023-10-11 16:08:54 +08:00
committed by GitHub
parent fc63914399
commit 1e3c9ddacc
6 changed files with 119 additions and 5 deletions

View File

@@ -46,6 +46,7 @@ if is_flax_available():
WhisperProcessor,
)
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
@require_flax
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_encoder_sinusoidal_embed_positions(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
params = model.params
if model.base_model_prefix in params:
params = model.params[model.base_model_prefix]
embeds = params["encoder"]["embed_positions"]["embedding"]
sinusoids = sinusoidal_embedding_init(None, embeds.shape)
self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
@slow
@require_flax

View File

@@ -42,7 +42,11 @@ if is_tf_available():
import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
from transformers.models.whisper.modeling_tf_whisper import (
TFWhisperDecoder,
TFWhisperEncoder,
sinusoidal_embedding_init,
)
def prepare_whisper_inputs_dict(
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_requires_grad_encoder_embed_positions(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
encoder = model.get_encoder()
self.assertFalse(encoder.embed_positions.trainable)
def test_encoder_sinusoidal_embed_positions(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
model.build()
embeds = model.get_encoder().embed_positions.get_weights()[0]
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
self.assertTrue(np.allclose(embeds, sinusoids))
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

View File

@@ -49,7 +49,7 @@ if is_torch_available():
WhisperProcessor,
set_seed,
)
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
if is_flax_available():
import jax.numpy as jnp
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertFalse(all(encoder_grads))
self.assertTrue(all(decoder_grads))
def test_requires_grad_encoder_embed_positions(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
encoder = model.get_encoder()
self.assertFalse(encoder.embed_positions.weight.requires_grad)
def test_encoder_sinusoidal_embed_positions(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
embeds = model.get_encoder().embed_positions.weight
self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)