[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_torch_available, logging
|
||||
from transformers import AutoModel, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
is_staging_test,
|
||||
require_torch,
|
||||
require_torch_multi_gpu,
|
||||
@@ -63,6 +64,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
@@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
class ModelUtilsTest(TestCasePlus):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
@@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
BertModel.from_pretrained(TINY_T5)
|
||||
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_config_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||
# model from the config object.
|
||||
|
||||
config = T5Config.from_pretrained(TINY_T5)
|
||||
model = AutoModel.from_config(config)
|
||||
# XXX: isn't supported
|
||||
# model = T5ForConditionalGeneration.from_config(config)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. config.torch_dtype setting in the saved model (priority)
|
||||
# 2. via autodiscovery by looking at model weights
|
||||
# so if a model.half() was saved, we want it to be instantiated as such.
|
||||
model_path = self.get_auto_remove_tmp_dir()
|
||||
|
||||
# baseline - we know TINY_T5 is fp32 model
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# test with auto-detection
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
# test forced loading in fp16 (even though the weights are in fp32)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with auto-detection
|
||||
model = model.half()
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user