Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)
* set encoder's PE as non-trainable * freeze flax * init sinusoids * add test for non-trainable embed positions * simplify TF encoder embed_pos * revert tf * clean up * add sinusoidal init for jax * make consistent sinusoidal function * fix dtype * add default dtype * use numpy for sinusoids. fix jax * add sinusoid init for TF * fix * use custom embedding * use specialized init for each impl * fix sinusoids init. add test for pytorch * fix TF dtype * simplify sinusoid init for flax and tf * add tests for TF * change default dtype to float32 * add sinusoid test for flax * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * move sinusoidal init to _init_weights --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" Flax whisper model."""
|
||||
|
||||
import math
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
@@ -58,6 +59,19 @@ _CONFIG_FOR_DOC = "WhisperConfig"
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
length, channels = shape
|
||||
if channels % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||
)
|
||||
log_timescale_increment = math.log(10000) / (channels // 2 - 1)
|
||||
inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
|
||||
scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
|
||||
return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
|
||||
|
||||
|
||||
WHISPER_START_DOCSTRING = r"""
|
||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@@ -649,7 +663,13 @@ class FlaxWhisperEncoder(nn.Module):
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
|
||||
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_source_positions,
|
||||
self.config.d_model,
|
||||
dtype=self.dtype,
|
||||
embedding_init=sinusoidal_embedding_init,
|
||||
)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||
|
||||
@@ -673,6 +693,8 @@ class FlaxWhisperEncoder(nn.Module):
|
||||
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
|
||||
|
||||
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
|
||||
# freeze the sinusoidal embeddings by stopping the back-prop
|
||||
embed_positions = jax.lax.stop_gradient(embed_positions)
|
||||
hidden_states = hidden_states + embed_positions
|
||||
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
@@ -59,6 +59,19 @@ TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
LARGE_NEGATIVE = -1e8
|
||||
|
||||
|
||||
def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
length, channels = shape
|
||||
if channels % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||
)
|
||||
log_timescale_increment = math.log(10000) / (channels // 2 - 1)
|
||||
inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
|
||||
scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
|
||||
return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||
@@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
|
||||
|
||||
class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
num_positions: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: Optional[int] = None,
|
||||
embedding_initializer=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.num_positions = num_positions
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.embedding_initializer = tf.keras.initializers.get(embedding_initializer)
|
||||
|
||||
def build(self, input_shape):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.num_positions, self.embedding_dim],
|
||||
initializer=self.embedding_initializer,
|
||||
trainable=True,
|
||||
)
|
||||
super().build(input_shape)
|
||||
@@ -620,8 +642,12 @@ class TFWhisperEncoder(tf.keras.layers.Layer):
|
||||
self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
|
||||
|
||||
self.embed_positions = TFWhisperPositionalEmbedding(
|
||||
self.max_source_positions, self.embed_dim, name="embed_positions"
|
||||
num_positions=self.max_source_positions,
|
||||
embedding_dim=self.embed_dim,
|
||||
embedding_initializer=sinusoidal_embedding_init,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_positions.trainable = False
|
||||
|
||||
self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
|
||||
@@ -55,6 +55,18 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
if channels % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||
)
|
||||
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
|
||||
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
"""
|
||||
@@ -668,6 +680,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, WhisperEncoder):
|
||||
with torch.no_grad():
|
||||
embed_positions = module.embed_positions.weight
|
||||
embed_positions.copy_(sinusoids(*embed_positions.shape))
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (WhisperDecoder, WhisperEncoder)):
|
||||
@@ -835,6 +851,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
||||
self.embed_positions.requires_grad_(False)
|
||||
|
||||
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
@@ -46,6 +46,7 @@ if is_flax_available():
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
|
||||
|
||||
|
||||
@require_flax
|
||||
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
params = model.params
|
||||
if model.base_model_prefix in params:
|
||||
params = model.params[model.base_model_prefix]
|
||||
|
||||
embeds = params["encoder"]["embed_positions"]["embedding"]
|
||||
sinusoids = sinusoidal_embedding_init(None, embeds.shape)
|
||||
self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
|
||||
@@ -42,7 +42,11 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
|
||||
from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
|
||||
from transformers.models.whisper.modeling_tf_whisper import (
|
||||
TFWhisperDecoder,
|
||||
TFWhisperEncoder,
|
||||
sinusoidal_embedding_init,
|
||||
)
|
||||
|
||||
|
||||
def prepare_whisper_inputs_dict(
|
||||
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_requires_grad_encoder_embed_positions(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
encoder = model.get_encoder()
|
||||
self.assertFalse(encoder.embed_positions.trainable)
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.build()
|
||||
|
||||
embeds = model.get_encoder().embed_positions.get_weights()[0]
|
||||
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
|
||||
self.assertTrue(np.allclose(embeds, sinusoids))
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
@@ -49,7 +49,7 @@ if is_torch_available():
|
||||
WhisperProcessor,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertFalse(all(encoder_grads))
|
||||
self.assertTrue(all(decoder_grads))
|
||||
|
||||
def test_requires_grad_encoder_embed_positions(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
encoder = model.get_encoder()
|
||||
self.assertFalse(encoder.embed_positions.weight.requires_grad)
|
||||
|
||||
def test_encoder_sinusoidal_embed_positions(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
embeds = model.get_encoder().embed_positions.weight
|
||||
self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user