Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)
* set encoder's PE as non-trainable * freeze flax * init sinusoids * add test for non-trainable embed positions * simplify TF encoder embed_pos * revert tf * clean up * add sinusoidal init for jax * make consistent sinusoidal function * fix dtype * add default dtype * use numpy for sinusoids. fix jax * add sinusoid init for TF * fix * use custom embedding * use specialized init for each impl * fix sinusoids init. add test for pytorch * fix TF dtype * simplify sinusoid init for flax and tf * add tests for TF * change default dtype to float32 * add sinusoid test for flax * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * move sinusoidal init to _init_weights --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Flax whisper model."""
|
""" Flax whisper model."""
|
||||||
|
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -58,6 +59,19 @@ _CONFIG_FOR_DOC = "WhisperConfig"
|
|||||||
remat = nn_partitioning.remat
|
remat = nn_partitioning.remat
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
|
||||||
|
"""Returns sinusoids for positional embedding"""
|
||||||
|
length, channels = shape
|
||||||
|
if channels % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||||
|
)
|
||||||
|
log_timescale_increment = math.log(10000) / (channels // 2 - 1)
|
||||||
|
inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
|
||||||
|
scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
|
||||||
|
return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
WHISPER_START_DOCSTRING = r"""
|
WHISPER_START_DOCSTRING = r"""
|
||||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
@@ -649,7 +663,13 @@ class FlaxWhisperEncoder(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
gradient_checkpointing=self.gradient_checkpointing,
|
gradient_checkpointing=self.gradient_checkpointing,
|
||||||
)
|
)
|
||||||
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
|
|
||||||
|
self.embed_positions = nn.Embed(
|
||||||
|
self.config.max_source_positions,
|
||||||
|
self.config.d_model,
|
||||||
|
dtype=self.dtype,
|
||||||
|
embedding_init=sinusoidal_embedding_init,
|
||||||
|
)
|
||||||
|
|
||||||
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
|
|
||||||
@@ -673,6 +693,8 @@ class FlaxWhisperEncoder(nn.Module):
|
|||||||
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
|
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
|
||||||
|
|
||||||
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
|
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
|
||||||
|
# freeze the sinusoidal embeddings by stopping the back-prop
|
||||||
|
embed_positions = jax.lax.stop_gradient(embed_positions)
|
||||||
hidden_states = hidden_states + embed_positions
|
hidden_states = hidden_states + embed_positions
|
||||||
|
|
||||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||||
|
|||||||
@@ -59,6 +59,19 @@ TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
LARGE_NEGATIVE = -1e8
|
LARGE_NEGATIVE = -1e8
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
|
||||||
|
"""Returns sinusoids for positional embedding"""
|
||||||
|
length, channels = shape
|
||||||
|
if channels % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||||
|
)
|
||||||
|
log_timescale_increment = math.log(10000) / (channels // 2 - 1)
|
||||||
|
inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
|
||||||
|
scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
|
||||||
|
return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
|
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
|
||||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||||
@@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
|||||||
|
|
||||||
|
|
||||||
class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
|
class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
|
||||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_positions: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
padding_idx: Optional[int] = None,
|
||||||
|
embedding_initializer=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_positions = num_positions
|
self.num_positions = num_positions
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.padding_idx = padding_idx
|
self.padding_idx = padding_idx
|
||||||
|
self.embedding_initializer = tf.keras.initializers.get(embedding_initializer)
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
self.weight = self.add_weight(
|
self.weight = self.add_weight(
|
||||||
name="weight",
|
name="weight",
|
||||||
shape=[self.num_positions, self.embedding_dim],
|
shape=[self.num_positions, self.embedding_dim],
|
||||||
|
initializer=self.embedding_initializer,
|
||||||
trainable=True,
|
trainable=True,
|
||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
@@ -620,8 +642,12 @@ class TFWhisperEncoder(tf.keras.layers.Layer):
|
|||||||
self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
|
self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
|
||||||
|
|
||||||
self.embed_positions = TFWhisperPositionalEmbedding(
|
self.embed_positions = TFWhisperPositionalEmbedding(
|
||||||
self.max_source_positions, self.embed_dim, name="embed_positions"
|
num_positions=self.max_source_positions,
|
||||||
|
embedding_dim=self.embed_dim,
|
||||||
|
embedding_initializer=sinusoidal_embedding_init,
|
||||||
|
name="embed_positions",
|
||||||
)
|
)
|
||||||
|
self.embed_positions.trainable = False
|
||||||
|
|
||||||
self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||||
|
|||||||
@@ -55,6 +55,18 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
|
||||||
|
"""Returns sinusoids for positional embedding"""
|
||||||
|
if channels % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
|
||||||
|
)
|
||||||
|
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
|
||||||
|
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||||
|
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
|
||||||
|
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
||||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
"""
|
"""
|
||||||
@@ -668,6 +680,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
|||||||
module.weight.data.normal_(mean=0.0, std=std)
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, WhisperEncoder):
|
||||||
|
with torch.no_grad():
|
||||||
|
embed_positions = module.embed_positions.weight
|
||||||
|
embed_positions.copy_(sinusoids(*embed_positions.shape))
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
if isinstance(module, (WhisperDecoder, WhisperEncoder)):
|
if isinstance(module, (WhisperDecoder, WhisperEncoder)):
|
||||||
@@ -835,6 +851,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|||||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
||||||
|
self.embed_positions.requires_grad_(False)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
|
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ if is_flax_available():
|
|||||||
WhisperProcessor,
|
WhisperProcessor,
|
||||||
)
|
)
|
||||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||||
|
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_encoder_sinusoidal_embed_positions(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
params = model.params
|
||||||
|
if model.base_model_prefix in params:
|
||||||
|
params = model.params[model.base_model_prefix]
|
||||||
|
|
||||||
|
embeds = params["encoder"]["embed_positions"]["embedding"]
|
||||||
|
sinusoids = sinusoidal_embedding_init(None, embeds.shape)
|
||||||
|
self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_flax
|
@require_flax
|
||||||
|
|||||||
@@ -42,7 +42,11 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
|
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
|
||||||
from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
|
from transformers.models.whisper.modeling_tf_whisper import (
|
||||||
|
TFWhisperDecoder,
|
||||||
|
TFWhisperEncoder,
|
||||||
|
sinusoidal_embedding_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_whisper_inputs_dict(
|
def prepare_whisper_inputs_dict(
|
||||||
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_requires_grad_encoder_embed_positions(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
encoder = model.get_encoder()
|
||||||
|
self.assertFalse(encoder.embed_positions.trainable)
|
||||||
|
|
||||||
|
def test_encoder_sinusoidal_embed_positions(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.build()
|
||||||
|
|
||||||
|
embeds = model.get_encoder().embed_positions.get_weights()[0]
|
||||||
|
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
|
||||||
|
self.assertTrue(np.allclose(embeds, sinusoids))
|
||||||
|
|
||||||
def test_decoder_model_past_with_large_inputs(self):
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ if is_torch_available():
|
|||||||
WhisperProcessor,
|
WhisperProcessor,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
self.assertFalse(all(encoder_grads))
|
self.assertFalse(all(encoder_grads))
|
||||||
self.assertTrue(all(decoder_grads))
|
self.assertTrue(all(decoder_grads))
|
||||||
|
|
||||||
|
def test_requires_grad_encoder_embed_positions(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
encoder = model.get_encoder()
|
||||||
|
self.assertFalse(encoder.embed_positions.weight.requires_grad)
|
||||||
|
|
||||||
|
def test_encoder_sinusoidal_embed_positions(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
embeds = model.get_encoder().embed_positions.weight
|
||||||
|
self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
|
||||||
|
|
||||||
def test_decoder_model_past_with_large_inputs(self):
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user