Fix: take into account meta device (#34134)
* Do not load for meta device * Make some minor improvements * Add test * Update tests/utils/test_modeling_utils.py Update test parameters Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Make the test simpler --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import glob
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
@@ -459,6 +460,19 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_meta_device(self):
|
||||
def is_on_meta(model_id, dtype):
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
|
||||
return all(value.device.type == "meta" for value in model.state_dict().values())
|
||||
|
||||
model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
|
||||
dtypes = (None, "auto", torch.float16)
|
||||
|
||||
for model_id, dtype in itertools.product(model_ids, dtypes):
|
||||
self.assertTrue(is_on_meta(model_id, dtype))
|
||||
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
|
||||
Reference in New Issue
Block a user