[from_pretrained] Make from_pretrained fast again (#27709)
* Skip nn.Module.reset_parameters * Actually skip * Check quality * Maybe change all inits * Fix init issues: only modify public functions * Add a small test for now * Style * test updates * style * nice tes * style * make it even faster * one more second * remove fx icompatible * Update tests/test_modeling_common.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Update tests/test_modeling_common.py Co-authored-by: Lysandre Debut <hi@lysand.re> * skip * fix quality * protect the import --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -36,8 +36,10 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
is_torch_available,
|
||||
logging,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
@@ -85,7 +87,7 @@ from transformers.utils import (
|
||||
is_torch_fx_available,
|
||||
is_torch_sdpa_available,
|
||||
)
|
||||
from transformers.utils.generic import ModelOutput
|
||||
from transformers.utils.generic import ContextManagers, ModelOutput
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -99,6 +101,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
@@ -428,6 +431,56 @@ class ModelTesterMixin:
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_fast_init_context_manager(self):
|
||||
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||
class MyClass(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config if config is not None else PretrainedConfig())
|
||||
self.linear = nn.Linear(10, 10, bias=True)
|
||||
self.embedding = nn.Embedding(10, 10)
|
||||
self.std = 1
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
|
||||
if module.bias is not None:
|
||||
module.bias.data.normal_(mean=0.0, std=self.std)
|
||||
|
||||
# 2. Make sure a linear layer's reset params is properly skipped:
|
||||
with ContextManagers([no_init_weights(True)]):
|
||||
no_init_instance = MyClass()
|
||||
|
||||
set_seed(0)
|
||||
expected_bias = torch.tensor(
|
||||
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
|
||||
)
|
||||
init_instance = MyClass()
|
||||
torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
|
||||
|
||||
set_seed(0)
|
||||
torch.testing.assert_allclose(
|
||||
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
|
||||
)
|
||||
|
||||
# 3. Make sure weights that are not present use init_weight_ and get expected values
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
state_dict = init_instance.state_dict()
|
||||
del state_dict["linear.weight"]
|
||||
|
||||
init_instance.config.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
set_seed(0)
|
||||
model_fast_init = MyClass.from_pretrained(tmpdirname)
|
||||
|
||||
set_seed(0)
|
||||
model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
|
||||
Reference in New Issue
Block a user