Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)
* Get parallel loader working. Include tests. * Update the tests for parallel loading * Rename env variables. * Add docs for parallel model weight loading. * Touch up parallel model loading docs. * Touch up parallel model loading docs again. * Edit comment in test_modeling_utils_parallel_loading.py * Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py * Correct times for parallelized loading, previous times were for a "hot" filesystem * Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule. * Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally. * Fix style on model loading parallelism changes. * Merge latest version of master's modeling_utils. * Removed unused variable. * Fix argument packing for the parallel loader. * Fix state dict being undefined in the parallel model loader. * Rename variables used in parallel model loading for clarity. Use get_module_from_name(). * Switch to the use of threads for parallel model loading. * Update docs for parallel loading. * Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting. * Move parallelized shard loading into its own function. * Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING. * Update copyright to 2025 in readme for paralell model loading. * Remove garbage collection line in load_shard_file, implicit garbage collection already occurs. * Run formatter on modeling_utils.py * Apply style fixes * Delete tests/utils/test_modeling_utils_parallel_loading.py --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -297,6 +297,27 @@ if is_torch_available():
|
||||
hub.TRANSFORMERS_CACHE = transformers_cache
|
||||
|
||||
|
||||
# Need to be serializable, which means they cannot be in a test class method
|
||||
class TestGammaBetaNorm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = torch.nn.Parameter(torch.ones(1))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self):
|
||||
return self.gamma.sum() + self.beta.sum()
|
||||
|
||||
|
||||
class TestModelGammaBeta(PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.LayerNorm = TestGammaBetaNorm()
|
||||
self.post_init()
|
||||
|
||||
def forward(self):
|
||||
return self.LayerNorm()
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxBertModel
|
||||
|
||||
@@ -1636,24 +1657,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])
|
||||
|
||||
def test_warning_for_beta_gamma_parameters(self):
|
||||
class TestGammaBetaNorm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = torch.nn.Parameter(torch.ones(1))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self):
|
||||
return self.gamma.sum() + self.beta.sum()
|
||||
|
||||
class TestModelGammaBeta(PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.LayerNorm = TestGammaBetaNorm()
|
||||
self.post_init()
|
||||
|
||||
def forward(self):
|
||||
return self.LayerNorm()
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
config = PretrainedConfig()
|
||||
warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
|
||||
|
||||
Reference in New Issue
Block a user