[from_pretrained] Make from_pretrained fast again (#27709)
* Skip nn.Module.reset_parameters * Actually skip * Check quality * Maybe change all inits * Fix init issues: only modify public functions * Add a small test for now * Style * test updates * style * nice tes * style * make it even faster * one more second * remove fx icompatible * Update tests/test_modeling_common.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Update tests/test_modeling_common.py Co-authored-by: Lysandre Debut <hi@lysand.re> * skip * fix quality * protect the import --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -154,6 +154,23 @@ else:
|
|||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
from .utils import find_adapter_config_file
|
from .utils import find_adapter_config_file
|
||||||
|
|
||||||
|
TORCH_INIT_FUNCTIONS = {
|
||||||
|
"uniform_": nn.init.uniform_,
|
||||||
|
"normal_": nn.init.normal_,
|
||||||
|
"trunc_normal_": nn.init.trunc_normal_,
|
||||||
|
"constant_": nn.init.constant_,
|
||||||
|
"xavier_uniform_": nn.init.xavier_uniform_,
|
||||||
|
"xavier_normal_": nn.init.xavier_normal_,
|
||||||
|
"kaiming_uniform_": nn.init.kaiming_uniform_,
|
||||||
|
"kaiming_normal_": nn.init.kaiming_normal_,
|
||||||
|
"uniform": nn.init.uniform,
|
||||||
|
"normal": nn.init.normal,
|
||||||
|
"xavier_uniform": nn.init.xavier_uniform,
|
||||||
|
"xavier_normal": nn.init.xavier_normal,
|
||||||
|
"kaiming_uniform": nn.init.kaiming_uniform,
|
||||||
|
"kaiming_normal": nn.init.kaiming_normal,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def no_init_weights(_enable=True):
|
def no_init_weights(_enable=True):
|
||||||
@@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
|
|||||||
"""
|
"""
|
||||||
global _init_weights
|
global _init_weights
|
||||||
old_init_weights = _init_weights
|
old_init_weights = _init_weights
|
||||||
|
|
||||||
if _enable:
|
if _enable:
|
||||||
_init_weights = False
|
_init_weights = False
|
||||||
|
|
||||||
|
def _skip_init(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# # Save the original initialization functions
|
||||||
|
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||||
|
setattr(torch.nn.init, name, _skip_init)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
_init_weights = old_init_weights
|
_init_weights = old_init_weights
|
||||||
|
if _enable:
|
||||||
|
# # Restore the original initialization functions
|
||||||
|
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||||
|
setattr(torch.nn.init, name, init_func)
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
||||||
@@ -1506,7 +1535,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""
|
"""
|
||||||
Initialize the weights. This method should be overridden by derived class.
|
Initialize the weights. This method should be overridden by derived class and is
|
||||||
|
the only initialization method that will be called when loading a checkpoint
|
||||||
|
using `from_pretrained`. Any attempt to initialize outside of this function
|
||||||
|
will be useless as the torch.nn.init function are all replaced with skip.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -3414,6 +3446,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
|
|
||||||
with ContextManagers(init_contexts):
|
with ContextManagers(init_contexts):
|
||||||
|
# Let's make sure we don't run the init function of buffer modules
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
# make sure we use the model's config since the __init__ call might have copied it
|
# make sure we use the model's config since the __init__ call might have copied it
|
||||||
|
|||||||
@@ -36,8 +36,10 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
|
PreTrainedModel,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.models.auto.modeling_auto import (
|
from transformers.models.auto.modeling_auto import (
|
||||||
@@ -85,7 +87,7 @@ from transformers.utils import (
|
|||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
)
|
)
|
||||||
from transformers.utils.generic import ModelOutput
|
from transformers.utils.generic import ContextManagers, ModelOutput
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
@@ -99,6 +101,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||||
|
from transformers.modeling_utils import no_init_weights
|
||||||
from transformers.pytorch_utils import id_tensor_storage
|
from transformers.pytorch_utils import id_tensor_storage
|
||||||
|
|
||||||
|
|
||||||
@@ -428,6 +431,56 @@ class ModelTesterMixin:
|
|||||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_fast_init_context_manager(self):
|
||||||
|
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||||
|
class MyClass(PreTrainedModel):
|
||||||
|
config_class = PretrainedConfig
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super().__init__(config if config is not None else PretrainedConfig())
|
||||||
|
self.linear = nn.Linear(10, 10, bias=True)
|
||||||
|
self.embedding = nn.Embedding(10, 10)
|
||||||
|
self.std = 1
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.normal_(mean=0.0, std=self.std)
|
||||||
|
|
||||||
|
# 2. Make sure a linear layer's reset params is properly skipped:
|
||||||
|
with ContextManagers([no_init_weights(True)]):
|
||||||
|
no_init_instance = MyClass()
|
||||||
|
|
||||||
|
set_seed(0)
|
||||||
|
expected_bias = torch.tensor(
|
||||||
|
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
|
||||||
|
)
|
||||||
|
init_instance = MyClass()
|
||||||
|
torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
|
||||||
|
|
||||||
|
set_seed(0)
|
||||||
|
torch.testing.assert_allclose(
|
||||||
|
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Make sure weights that are not present use init_weight_ and get expected values
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
state_dict = init_instance.state_dict()
|
||||||
|
del state_dict["linear.weight"]
|
||||||
|
|
||||||
|
init_instance.config.save_pretrained(tmpdirname)
|
||||||
|
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||||
|
set_seed(0)
|
||||||
|
model_fast_init = MyClass.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
set_seed(0)
|
||||||
|
model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
|
||||||
|
|
||||||
|
for key in model_fast_init.state_dict().keys():
|
||||||
|
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||||
|
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_save_load_fast_init_to_base(self):
|
def test_save_load_fast_init_to_base(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if config.__class__ not in MODEL_MAPPING:
|
if config.__class__ not in MODEL_MAPPING:
|
||||||
|
|||||||
Reference in New Issue
Block a user