[from_pretrained] Make from_pretrained fast again (#27709)

* Skip nn.Module.reset_parameters

* Actually skip

* Check quality

* Maybe change all inits

* Fix init issues: only modify public functions

* Add a small test for now

* Style

* test updates

* style

* nice tes

* style

* make it even faster

* one more second

* remove fx icompatible

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* skip

* fix quality

* protect the import

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur
2023-12-11 12:38:17 +01:00
committed by GitHub
parent 9f18cc6df0
commit 0676d992a5
2 changed files with 88 additions and 2 deletions

View File

@@ -36,8 +36,10 @@ from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
logging,
set_seed,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import (
@@ -85,7 +87,7 @@ from transformers.utils import (
is_torch_fx_available,
is_torch_sdpa_available,
)
from transformers.utils.generic import ModelOutput
from transformers.utils.generic import ContextManagers, ModelOutput
if is_accelerate_available():
@@ -99,6 +101,7 @@ if is_torch_available():
from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.modeling_utils import no_init_weights
from transformers.pytorch_utils import id_tensor_storage
@@ -428,6 +431,56 @@ class ModelTesterMixin:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_fast_init_context_manager(self):
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
class MyClass(PreTrainedModel):
config_class = PretrainedConfig
def __init__(self, config=None):
super().__init__(config if config is not None else PretrainedConfig())
self.linear = nn.Linear(10, 10, bias=True)
self.embedding = nn.Embedding(10, 10)
self.std = 1
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
if module.bias is not None:
module.bias.data.normal_(mean=0.0, std=self.std)
# 2. Make sure a linear layer's reset params is properly skipped:
with ContextManagers([no_init_weights(True)]):
no_init_instance = MyClass()
set_seed(0)
expected_bias = torch.tensor(
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
)
init_instance = MyClass()
torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
set_seed(0)
torch.testing.assert_allclose(
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
)
# 3. Make sure weights that are not present use init_weight_ and get expected values
with tempfile.TemporaryDirectory() as tmpdirname:
state_dict = init_instance.state_dict()
del state_dict["linear.weight"]
init_instance.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
set_seed(0)
model_fast_init = MyClass.from_pretrained(tmpdirname)
set_seed(0)
model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING: