Add loading speed test (#36671)

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* trigger CIs

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* better error messages

* Update test_modeling_utils.py

* Update test_modeling_utils.py
This commit is contained in:
Cyril Vallez
2025-03-13 17:07:30 +01:00
committed by GitHub
parent a3201cea14
commit 2a004f9ff1

View File

@@ -17,8 +17,10 @@ import glob
import json import json
import os import os
import os.path import os.path
import subprocess
import sys import sys
import tempfile import tempfile
import textwrap
import threading import threading
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
@@ -28,6 +30,7 @@ from pathlib import Path
import requests import requests
from huggingface_hub import HfApi, HfFolder from huggingface_hub import HfApi, HfFolder
from parameterized import parameterized
from pytest import mark from pytest import mark
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@@ -55,10 +58,12 @@ from transformers.testing_utils import (
is_staging_test, is_staging_test,
require_accelerate, require_accelerate,
require_flax, require_flax,
require_read_token,
require_safetensors, require_safetensors,
require_tf, require_tf,
require_torch, require_torch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator, require_torch_multi_accelerator,
require_usr_bin_time, require_usr_bin_time,
slow, slow,
@@ -1900,6 +1905,61 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(len(cm.records), 1) self.assertEqual(len(cm.records), 1)
self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got")) self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))
@parameterized.expand([("Qwen/Qwen2.5-3B-Instruct", 10), ("meta-llama/Llama-2-7b-chat-hf", 10)])
@slow
@require_read_token
@require_torch_gpu
def test_loading_is_fast_on_gpu(self, model_id: str, max_loading_time: float):
"""
This test is used to avoid regresion on https://github.com/huggingface/transformers/pull/36380.
10s should be more than enough for both models, and allows for some margin as loading time are quite
unstable. Before #36380, it used to take more than 40s, so 10s is still reasonable.
Note that we run this test in a subprocess, to ensure that cuda is not already initialized/warmed-up.
"""
# First download the weights if not already on disk
_ = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
script_to_run = textwrap.dedent(
"""
import torch
import time
import argparse
from transformers import AutoModelForCausalLM
parser = argparse.ArgumentParser()
parser.add_argument("model_id", type=str)
parser.add_argument("max_loading_time", type=float)
args = parser.parse_args()
device = torch.device("cuda:0")
torch.cuda.synchronize(device)
t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16, device_map=device)
torch.cuda.synchronize(device)
dt = time.time() - t0
# Assert loading is faster (it should be more than enough in both cases)
if dt > args.max_loading_time:
raise ValueError(f"Loading took {dt:.2f}s! It should not take more than {args.max_loading_time}s")
# Ensure everything is correctly loaded on gpu
bad_device_params = {k for k, v in model.named_parameters() if v.device != device}
if len(bad_device_params) > 0:
raise ValueError(f"The following parameters are not on GPU: {bad_device_params}")
"""
)
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script_to_run)
tmp.flush()
tmp.seek(0)
cmd = f"python {tmp.name} {model_id} {max_loading_time}".split()
try:
# We cannot use a timeout of `max_loading_time` as cuda initialization can take up to 15-20s
_ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True, timeout=60)
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")
@slow @slow
@require_torch @require_torch