[from_pretrained] Make from_pretrained fast again (#27709)

* Skip nn.Module.reset_parameters

* Actually skip

* Check quality

* Maybe change all inits

* Fix init issues: only modify public functions

* Add a small test for now

* Style

* test updates

* style

* nice test

* style

* make it even faster

* one more second

* remove fx incompatible

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* skip

* fix quality

* protect the import

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur
2023-12-11 12:38:17 +01:00
committed by GitHub
parent 9f18cc6df0
commit 0676d992a5
2 changed files with 88 additions and 2 deletions

View File

@@ -154,6 +154,23 @@ else:
if is_peft_available(): if is_peft_available():
from .utils import find_adapter_config_file from .utils import find_adapter_config_file
# Snapshot of the public `torch.nn.init` functions, keyed by attribute name.
# The trailing-underscore names are listed first, followed by the
# non-underscore names. Each value is whatever object `nn.init` exposes under
# that name when this module is imported, so the originals can be put back
# after they have been temporarily replaced on `torch.nn.init`.
TORCH_INIT_FUNCTIONS = {
    init_name: getattr(nn.init, init_name)
    for init_name in (
        "uniform_",
        "normal_",
        "trunc_normal_",
        "constant_",
        "xavier_uniform_",
        "xavier_normal_",
        "kaiming_uniform_",
        "kaiming_normal_",
        "uniform",
        "normal",
        "xavier_uniform",
        "xavier_normal",
        "kaiming_uniform",
        "kaiming_normal",
    )
}
@contextmanager @contextmanager
def no_init_weights(_enable=True): def no_init_weights(_enable=True):
@@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
""" """
global _init_weights global _init_weights
old_init_weights = _init_weights old_init_weights = _init_weights
if _enable: if _enable:
_init_weights = False _init_weights = False
def _skip_init(*args, **kwargs):
pass
# # Save the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, _skip_init)
try: try:
yield yield
finally: finally:
_init_weights = old_init_weights _init_weights = old_init_weights
if _enable:
# # Restore the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, init_func)
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
@@ -1506,7 +1535,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def _init_weights(self, module): def _init_weights(self, module):
""" """
Initialize the weights. This method should be overridden by derived class. Initialize the weights. This method should be overridden by derived class and is
the only initialization method that will be called when loading a checkpoint
using `from_pretrained`. Any attempt to initialize outside of this function
will be useless as the torch.nn.init functions are all replaced with skip.
""" """
pass pass
@@ -3414,6 +3446,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
with ContextManagers(init_contexts): with ContextManagers(init_contexts):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs) model = cls(config, *model_args, **model_kwargs)
# make sure we use the model's config since the __init__ call might have copied it # make sure we use the model's config since the __init__ call might have copied it

View File

@@ -36,8 +36,10 @@ from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
PretrainedConfig, PretrainedConfig,
PreTrainedModel,
is_torch_available, is_torch_available,
logging, logging,
set_seed,
) )
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import ( from transformers.models.auto.modeling_auto import (
@@ -85,7 +87,7 @@ from transformers.utils import (
is_torch_fx_available, is_torch_fx_available,
is_torch_sdpa_available, is_torch_sdpa_available,
) )
from transformers.utils.generic import ModelOutput from transformers.utils.generic import ContextManagers, ModelOutput
if is_accelerate_available(): if is_accelerate_available():
@@ -99,6 +101,7 @@ if is_torch_available():
from torch import nn from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.modeling_utils import no_init_weights
from transformers.pytorch_utils import id_tensor_storage from transformers.pytorch_utils import id_tensor_storage
@@ -428,6 +431,56 @@ class ModelTesterMixin:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_fast_init_context_manager(self):
    """Exercise `no_init_weights` on a tiny model and compare fast vs. slow init.

    The test proceeds in three steps:

    1. Define a minimal `PreTrainedModel` subclass holding one `nn.Linear` and one `nn.Embedding`.
    2. Build one instance inside `no_init_weights(True)` and one outside of it (after seeding), then check
       the seeded instance's linear bias against hard-coded values and its linear weight against a
       re-seeded `kaiming_uniform_` call applied to the instance built inside the context manager.
    3. Save a checkpoint with `linear.weight` removed and verify that `from_pretrained` yields the same
       state dict with and without `_fast_init=False`.
    """

    # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
    class MyClass(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self, config=None):
            super().__init__(config if config is not None else PretrainedConfig())
            self.linear = nn.Linear(10, 10, bias=True)
            self.embedding = nn.Embedding(10, 10)
            self.std = 1

        def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
                if module.bias is not None:
                    module.bias.data.normal_(mean=0.0, std=self.std)

    # 2. Make sure a linear layer's reset params is properly skipped:
    with ContextManagers([no_init_weights(True)]):
        no_init_instance = MyClass()

    # The seed is set right before the model is built: the expected values below depend on the
    # random draws made while constructing `MyClass`.
    set_seed(0)
    expected_bias = torch.tensor(
        [0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475]
    )
    init_instance = MyClass()
    # `torch.testing.assert_allclose` is deprecated in favour of `torch.testing.assert_close`.
    torch.testing.assert_close(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)

    # Re-seed so that initializing the (so far untouched) weight of `no_init_instance` replays the
    # same draw that produced `init_instance.linear.weight`.
    set_seed(0)
    torch.testing.assert_close(
        init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
    )

    # 3. Make sure weights that are not present use init_weight_ and get expected values
    with tempfile.TemporaryDirectory() as tmpdirname:
        state_dict = init_instance.state_dict()
        del state_dict["linear.weight"]

        init_instance.config.save_pretrained(tmpdirname)
        torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

        set_seed(0)
        model_fast_init = MyClass.from_pretrained(tmpdirname)

        set_seed(0)
        model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)

        # Build each state dict once instead of once per key.
        fast_state = model_fast_init.state_dict()
        slow_state = model_slow_init.state_dict()
        for key, fast_value in fast_state.items():
            max_diff = torch.max(torch.abs(slow_state[key] - fast_value))
            self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
def test_save_load_fast_init_to_base(self): def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING: if config.__class__ not in MODEL_MAPPING: