From 1e3c9ddacc7fc4142253bc9ddcba85c4d5b977e7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 11 Oct 2023 16:08:54 +0800 Subject: [PATCH] Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032) * set encoder's PE as non-trainable * freeze flax * init sinusoids * add test for non-trainable embed positions * simplify TF encoder embed_pos * revert tf * clean up * add sinusoidal init for jax * make consistent sinusoidal function * fix dtype * add default dtype * use numpy for sinusoids. fix jax * add sinusoid init for TF * fix * use custom embedding * use specialized init for each impl * fix sinusoids init. add test for pytorch * fix TF dtype * simplify sinusoid init for flax and tf * add tests for TF * change default dtype to float32 * add sinusoid test for flax * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * move sinusoidal init to _init_weights --------- Co-authored-by: sanchit-gandhi Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/modeling_flax_whisper.py | 24 ++++++++++++++- .../models/whisper/modeling_tf_whisper.py | 30 +++++++++++++++++-- .../models/whisper/modeling_whisper.py | 17 +++++++++++ .../whisper/test_modeling_flax_whisper.py | 14 +++++++++ .../whisper/test_modeling_tf_whisper.py | 23 +++++++++++++- tests/models/whisper/test_modeling_whisper.py | 16 +++++++++- 6 files changed, 119 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 0f158fb602..ffcaeb53ad 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -14,6 +14,7 @@ # limitations under the License. """ Flax whisper model.""" +import math import random from functools import partial from typing import Optional, Tuple @@ -58,6 +59,19 @@ _CONFIG_FOR_DOC = "WhisperConfig" remat = nn_partitioning.remat +def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array: + """Returns sinusoids for positional embedding""" + length, channels = shape + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(10000) / (channels // 2 - 1) + inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2)) + scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1) + return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype) + + WHISPER_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads @@ -649,7 +663,13 @@ class FlaxWhisperEncoder(nn.Module): dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, ) - self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype) + + self.embed_positions = nn.Embed( + self.config.max_source_positions, + self.config.d_model, + dtype=self.dtype, + embedding_init=sinusoidal_embedding_init, + ) self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) @@ -673,6 +693,8 @@ class FlaxWhisperEncoder(nn.Module): hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) + # freeze the sinusoidal embeddings by stopping the back-prop + embed_positions = jax.lax.stop_gradient(embed_positions) hidden_states = hidden_states + embed_positions hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index 27b6ff63ce..1dfe413da2 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -59,6 +59,19 @@ TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [ LARGE_NEGATIVE = -1e8 +def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor: + """Returns sinusoids for positional embedding""" + length, channels = shape + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(10000) / (channels // 2 - 1) + inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32)) + scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1)) + return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype) + + # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): pad_token_id = tf.cast(pad_token_id, input_ids.dtype) @@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): class TFWhisperPositionalEmbedding(tf.keras.layers.Layer): - def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs): + def __init__( + self, + num_positions: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + embedding_initializer=None, + **kwargs, + ): super().__init__(**kwargs) self.num_positions = num_positions self.embedding_dim = embedding_dim self.padding_idx = padding_idx + self.embedding_initializer = tf.keras.initializers.get(embedding_initializer) def build(self, input_shape): self.weight = self.add_weight( name="weight", shape=[self.num_positions, self.embedding_dim], + initializer=self.embedding_initializer, trainable=True, ) super().build(input_shape) @@ -620,8 +642,12 @@ class TFWhisperEncoder(tf.keras.layers.Layer): self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2") self.embed_positions = TFWhisperPositionalEmbedding( - self.max_source_positions, self.embed_dim, name="embed_positions" + num_positions=self.max_source_positions, + embedding_dim=self.embed_dim, + embedding_initializer=sinusoidal_embedding_init, + name="embed_positions", ) + self.embed_positions.trainable = False self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 447d7275d5..be5f50dbff 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,18 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + + # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ @@ -668,6 +680,10 @@ class WhisperPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, WhisperEncoder): + with torch.no_grad(): + embed_positions = module.embed_positions.weight + embed_positions.copy_(sinusoids(*embed_positions.shape)) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (WhisperDecoder, WhisperEncoder)): @@ -835,6 +851,7 @@ class WhisperEncoder(WhisperPreTrainedModel): self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.requires_grad_(False) self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 7ec5f90f0f..982dcb4827 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -46,6 +46,7 @@ if is_flax_available(): WhisperProcessor, ) from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init @require_flax @@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_encoder_sinusoidal_embed_positions(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + params = model.params + if model.base_model_prefix in params: + params = model.params[model.base_model_prefix] + + embeds = params["encoder"]["embed_positions"]["embedding"] + sinusoids = sinusoidal_embedding_init(None, embeds.shape) + self.assertTrue(jax.numpy.allclose(embeds, sinusoids)) + @slow @require_flax diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 7fae1e466e..75c62ae1ad 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -42,7 +42,11 @@ if is_tf_available(): import tensorflow as tf from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed - from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder + from transformers.models.whisper.modeling_tf_whisper import ( + TFWhisperDecoder, + TFWhisperEncoder, + sinusoidal_embedding_init, + ) def prepare_whisper_inputs_dict( @@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_forward(*config_and_inputs) + def test_requires_grad_encoder_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + encoder = model.get_encoder() + self.assertFalse(encoder.embed_positions.trainable) + + def test_encoder_sinusoidal_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.build() + + embeds = model.get_encoder().embed_positions.get_weights()[0] + sinusoids = sinusoidal_embedding_init(embeds.shape).numpy() + self.assertTrue(np.allclose(embeds, sinusoids)) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9decb7192a..337d334852 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -49,7 +49,7 @@ if is_torch_available(): WhisperProcessor, set_seed, ) - from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder + from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids if is_flax_available(): import jax.numpy as jnp @@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi self.assertFalse(all(encoder_grads)) self.assertTrue(all(decoder_grads)) + def test_requires_grad_encoder_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + encoder = model.get_encoder() + self.assertFalse(encoder.embed_positions.weight.requires_grad) + + def test_encoder_sinusoidal_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + embeds = model.get_encoder().embed_positions.weight + self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape))) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)