🔴 🚨 Resizing tokens embeddings: initialize from old embeddings' normal distribution. (#33325)

* Initialize new embeddings from normal distribution

* Fix typo in comments

* Fix typo in comments

* Fix style

* Fix variables naming

* Add tests

* Fix style

* code consistency nit

* Add deepspeed support

* Add deepspeed support

* Convert embeddings weights to float32 before computations

* Add deepspeed tests

* Cover when vocab_size is smaller than embedding_size

* Style fix

* Add tests for vocab_size smaller than hidden_size

* Style fix

* Nits in tests

* Nits in tests

* Check for deepspeed before importing it

* Increase vocab_size for positive definite covariance matrix test

* Add warning

* Add multivariate_resizing flag and implement resizing for lm_heads

* Fix typo

* Fix wrong bias indexing

* Fix bias is zero check

* remove multivariate_resizing flag from tests

* Initialize bias from old bias normal distribution

* Fixup

* Code usability

* Use mean_resizing instead of multivariate_resizing

* Fix up

* Fix comments and docs
This commit is contained in:
Mohamed Abu El-Nasr
2024-10-04 17:29:55 +03:00
committed by GitHub
parent b916efcb3c
commit 78ef58325c
2 changed files with 314 additions and 24 deletions

View File

@@ -25,6 +25,7 @@ import tempfile
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List, Tuple
import numpy as np
@@ -45,6 +46,12 @@ from transformers import (
logging,
set_seed,
)
from transformers.integrations import HfDeepSpeedConfig
from transformers.integrations.deepspeed import (
is_deepspeed_available,
is_deepspeed_zero3_enabled,
unset_hf_deepspeed_config,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -75,6 +82,7 @@ from transformers.testing_utils import (
is_pt_tf_cross_test,
require_accelerate,
require_bitsandbytes,
require_deepspeed,
require_flash_attn,
require_non_xpu,
require_read_token,
@@ -134,6 +142,9 @@ if is_flax_available():
if is_torch_fx_available():
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
if is_deepspeed_available():
import deepspeed
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -171,6 +182,15 @@ def _mock_all_init_weights(self):
self.tie_weights()
@contextmanager
def _deepspeed_zero3(ds_config):
    """Build an `HfDeepSpeedConfig` from `ds_config` and yield it to the `with` body.

    `unset_hf_deepspeed_config()` is called when the `with` block exits, whether it finishes normally or raises.
    """
    hf_ds_config = HfDeepSpeedConfig(ds_config)
    try:
        yield hf_ds_config
    finally:
        # Always undo the DeepSpeed config so it cannot leak into other tests.
        unset_hf_deepspeed_config()
@require_torch
class ModelTesterMixin:
model_tester = None
@@ -1797,8 +1817,13 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init():
model = model_class(config)
else:
model = model_class(config)
model.to(torch_device)
model_embed_pre_resize = model.get_input_embeddings()
type_model_embed_pre_resize = type(model_embed_pre_resize)
@@ -1813,15 +1838,26 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check to make sure the type of embeddings returned post resizing is same as type of input
type_model_embed_post_resize = type(model_embed)
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
# Check that added embeddings mean is close to the old embeddings mean
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
else:
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
if not is_deepspeed_zero3_enabled():
# A distributed launcher is needed for the forward pass when deepspeed is enabled
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
@@ -1835,9 +1871,11 @@ class ModelTesterMixin:
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
# make sure that decoder_input_ids are resized as well
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
if not is_deepspeed_zero3_enabled():
# A distributed launcher is needed for the forward pass when deepspeed is enabled
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
@@ -1847,9 +1885,13 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
del model
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init():
model = model_class(config)
else:
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
@@ -1877,6 +1919,63 @@ class ModelTesterMixin:
):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
# Test when `vocab_size` is smaller than `hidden_size`.
del model
config.vocab_size = 4
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init():
model = model_class(config)
else:
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.get_text_config().vocab_size
# Retrieve the embeddings and clone them
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check to make sure the type of embeddings returned post resizing is same as type of input
type_model_embed_post_resize = type(model_embed)
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
# Check that added embeddings mean is close to the old embeddings mean
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
else:
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
@require_deepspeed
@require_torch_gpu
def test_resize_tokens_embeddings_with_deepspeed(self):
    # Re-run `test_resize_tokens_embeddings` with a ZeRO stage 3 config that offloads parameters to
    # pinned CPU memory.
    cpu_param_offload = {"device": "cpu", "pin_memory": True}
    zero3_config = {"zero_optimization": {"stage": 3, "offload_param": cpu_param_offload}}
    with _deepspeed_zero3(zero3_config):
        self.test_resize_tokens_embeddings()
@require_deepspeed
@require_torch_multi_gpu
def test_resize_tokens_embeddings_with_deepspeed_multi_gpu(self):
    # Re-run `test_resize_tokens_embeddings` with a ZeRO stage 3 config (no parameter offload section).
    zero3_config = {"zero_optimization": {"stage": 3}}
    with _deepspeed_zero3(zero3_config):
        self.test_resize_tokens_embeddings()
def test_resize_embeddings_untied(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
@@ -1890,7 +1989,11 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config).to(torch_device)
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init():
model = model_class(config)
else:
model = model_class(config).to(torch_device)
# if no output embeddings -> leave test
if model.get_output_embeddings() is None:
@@ -1907,7 +2010,33 @@ class ModelTesterMixin:
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
if not is_deepspeed_zero3_enabled():
# A distributed launcher is needed for the forward pass when deepspeed is enabled
model(**self._prepare_for_class(inputs_dict, model_class))
# Test multivariate resizing.
model.resize_token_embeddings(model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
# Check that added embeddings mean is close to the old embeddings mean
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(output_embeds.weight, modifier_rank=None):
old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0)
else:
old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0)
new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0)
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
# Check that the added bias mean is close to the old bias mean.
if output_embeds.bias is not None:
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(output_embeds.bias, modifier_rank=None):
old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0)
new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0)
else:
old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0)
new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0)
torch.testing.assert_close(old_bias_mean, new_bias_mean, atol=1e-5, rtol=1e-2)
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
@@ -1925,7 +2054,32 @@ class ModelTesterMixin:
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
if not is_deepspeed_zero3_enabled():
# A distributed launcher is needed for the forward pass when deepspeed is enabled
model(**self._prepare_for_class(inputs_dict, model_class))
@require_deepspeed
@require_torch_gpu
def test_resize_embeddings_untied_with_deepspeed(self):
    # Re-run `test_resize_embeddings_untied` with a ZeRO stage 3 config that offloads parameters to
    # pinned CPU memory.
    cpu_param_offload = {"device": "cpu", "pin_memory": True}
    zero3_config = {"zero_optimization": {"stage": 3, "offload_param": cpu_param_offload}}
    with _deepspeed_zero3(zero3_config):
        self.test_resize_embeddings_untied()
@require_deepspeed
@require_torch_multi_gpu
def test_resize_embeddings_untied_with_deepspeed_multi_gpu(self):
    # Re-run `test_resize_embeddings_untied` with a ZeRO stage 3 config (no parameter offload section).
    zero3_config = {"zero_optimization": {"stage": 3}}
    with _deepspeed_zero3(zero3_config):
        self.test_resize_embeddings_untied()
def test_model_get_set_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()