Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)

* Get parallel loader working. Include tests.

* Update the tests for parallel loading

* Rename env variables.

* Add docs for parallel model weight loading.

* Touch up parallel model loading docs.

* Touch up parallel model loading docs again.

* Edit comment in test_modeling_utils_parallel_loading.py

* Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py

* Correct times for parallelized loading, previous times were for a "hot" filesystem

* Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule.

* Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally.

* Fix style on model loading parallelism changes.

* Merge latest version of master's modeling_utils.

* Removed unused variable.

* Fix argument packing for the parallel loader.

* Fix state dict being undefined in the parallel model loader.

* Rename variables used in parallel model loading for clarity. Use get_module_from_name().

* Switch to the use of threads for parallel model loading.

* Update docs for parallel loading.

* Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting.

* Move parallelized shard loading into its own function.

* Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING.

* Update copyright to 2025 in readme for parallel model loading.

* Remove garbage collection line in load_shard_file, implicit garbage collection already occurs.

* Run formatter on modeling_utils.py

* Apply style fixes

* Delete tests/utils/test_modeling_utils_parallel_loading.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Aaron V
2025-05-23 12:39:47 -04:00
committed by GitHub
parent 1ed19360b1
commit d5f992f5e6
4 changed files with 234 additions and 76 deletions
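
The commit message describes an opt-in, thread-based loader controlled by `HF_ENABLE_PARALLEL_LOADING` and `HF_PARALLEL_LOADING_WORKERS`. The following is a minimal sketch of that idea, not the code from `modeling_utils.py`: the helper names (`parallel_loading_enabled`, `worker_count`, `load_shards`), the set of accepted true values, and the default of 8 workers are all assumptions made for illustration.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Mapping, Sequence

# Assumed set of "true" strings; the real accepted values may differ.
TRUE_VALUES = {"1", "true", "yes", "on"}


def parallel_loading_enabled(env: Mapping[str, str] = os.environ) -> bool:
    """Return True when HF_ENABLE_PARALLEL_LOADING holds a true-like value."""
    return env.get("HF_ENABLE_PARALLEL_LOADING", "").strip().lower() in TRUE_VALUES


def worker_count(env: Mapping[str, str] = os.environ, default: int = 8) -> int:
    """Read HF_PARALLEL_LOADING_WORKERS with a simple int cast (default is assumed)."""
    return int(env.get("HF_PARALLEL_LOADING_WORKERS", default))


def load_shards(
    shard_paths: Sequence[str],
    load_one: Callable[[str], object],
    env: Mapping[str, str] = os.environ,
) -> List[object]:
    """Load every shard, using a thread pool only when the flag is set."""
    if not parallel_loading_enabled(env):
        # Sequential fallback: behavior is unchanged when the flag is absent.
        return [load_one(path) for path in shard_paths]
    # Never start more threads than there are shards, and always at least one.
    workers = max(1, min(worker_count(env), len(shard_paths)))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # pool.map returns results in the same order as shard_paths.
        return list(pool.map(load_one, shard_paths))
```

Here `load_one` stands in for whatever function reads a single shard file. Passing `env` explicitly keeps the sketch easy to exercise without mutating the process environment.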


@@ -297,6 +297,27 @@ if is_torch_available():
hub.TRANSFORMERS_CACHE = transformers_cache
    # Need to be serializable, which means they cannot be in a test class method
    class TestGammaBetaNorm(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gamma = torch.nn.Parameter(torch.ones(1))
            self.beta = torch.nn.Parameter(torch.zeros(1))

        def forward(self):
            return self.gamma.sum() + self.beta.sum()

    class TestModelGammaBeta(PreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.LayerNorm = TestGammaBetaNorm()
            self.post_init()

        def forward(self):
            return self.LayerNorm()

if is_flax_available():
    from transformers import FlaxBertModel
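
The comment in this hunk says the test classes must be serializable and therefore cannot be defined inside a test method. A small sketch of the underlying Python behavior, assuming the serialization in question is `pickle` (as used when objects cross a process boundary); `ModuleLevel`, `make_local_class`, and `can_pickle` are illustrative names only:

```python
import pickle


class ModuleLevel:
    """Defined at module scope, so pickle can refer to it by qualified name."""


def make_local_class():
    # A class created inside a function has "<locals>" in its qualified name,
    # which pickle cannot resolve when it tries to look the class up again.
    class Local:
        pass

    return Local


def can_pickle(obj) -> bool:
    """Return True if obj survives pickle.dumps, False if pickling fails."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


if __name__ == "__main__":
    print("module-level instance:", can_pickle(ModuleLevel()))
    print("function-local instance:", can_pickle(make_local_class()()))
```

Moving the two test classes to module scope, as this hunk does, avoids the failure shown for the function-local class.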
@@ -1636,24 +1657,6 @@ class ModelUtilsTest(TestCasePlus):
torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])
    def test_warning_for_beta_gamma_parameters(self):
        class TestGammaBetaNorm(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.ones(1))
                self.beta = torch.nn.Parameter(torch.zeros(1))

            def forward(self):
                return self.gamma.sum() + self.beta.sum()

        class TestModelGammaBeta(PreTrainedModel):
            def __init__(self, config):
                super().__init__(config)
                self.LayerNorm = TestGammaBetaNorm()
                self.post_init()

            def forward(self):
                return self.LayerNorm()

        logger = logging.get_logger("transformers.modeling_utils")
        config = PretrainedConfig()
        warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"