[models] respect dtype of the model when instantiating it (#12316)

* [models] respect dtype of the model when instantiating it

* cleanup

* cleanup

* rework to handle non-float dtype

* fix

* switch to fp32 tiny model

* improve

* use dtype.is_floating_point

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix the doc

* recode to use explicit torch_dtype_auto_detect, torch_dtype args

* docs and tweaks

* docs and tweaks

* docs and tweaks

* merge 2 args, add docs

* fix

* fix

* better doc

* better doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-06-28 20:11:21 -07:00
committed by GitHub
parent 31c3e7e75b
commit 7682e97702
8 changed files with 221 additions and 26 deletions

View File

@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_torch_available, logging
from transformers import AutoModel, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
PASS,
USER,
CaptureLogger,
TestCasePlus,
is_staging_test,
require_torch,
require_torch_multi_gpu,
@@ -63,6 +64,7 @@ if is_torch_available():
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
)
@@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
@require_torch
class ModelUtilsTest(unittest.TestCase):
class ModelUtilsTest(TestCasePlus):
@slow
def test_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase):
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
@require_torch
def test_model_from_config_torch_dtype(self):
    """Check that ``from_config`` instantiates the model in the requested float dtype.

    The dtype is requested through the ``torch_dtype`` keyword argument of ``from_config``. When it
    is omitted the tiny T5 model is expected to come up in fp32, and a non-floating-point dtype is
    expected to be rejected with ``ValueError``.
    """
    config = T5Config.from_pretrained(TINY_T5)

    # baseline: no dtype requested
    # XXX: `T5ForConditionalGeneration.from_config(config)` isn't supported, hence `AutoModel`
    model = AutoModel.from_config(config)
    self.assertEqual(model.dtype, torch.float32)

    # explicit float dtype chosen by the user
    model = AutoModel.from_config(config, torch_dtype=torch.float16)
    self.assertEqual(model.dtype, torch.float16)

    # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
    with self.assertRaises(ValueError):
        AutoModel.from_config(config, torch_dtype=torch.int64)
@require_torch
def test_model_from_pretrained_torch_dtype(self):
    """Exercise dtype selection in ``from_pretrained`` across save/load round trips.

    Covered, in order: the fp32 baseline checkpoint, a plain fp32 save/load cycle, loading with
    ``torch_dtype="auto"``, forcing fp16 on fp32 weights, and reloading a checkpoint saved after
    ``half()`` both with ``"auto"`` and with an explicit fp16 request.
    """
    save_dir = self.get_auto_remove_tmp_dir()
    fp32 = torch.float32
    fp16 = torch.float16

    # Baseline: TINY_T5 is known to be an fp32 checkpoint.
    t5 = T5ForConditionalGeneration.from_pretrained(TINY_T5)
    self.assertEqual(t5.dtype, fp32)

    # Default round trip: save in fp32, load without any dtype hint.
    t5.save_pretrained(save_dir)
    t5 = T5ForConditionalGeneration.from_pretrained(save_dir)
    self.assertEqual(t5.dtype, fp32)

    # Same fp32 checkpoint, dtype picked by auto-detection.
    t5 = T5ForConditionalGeneration.from_pretrained(save_dir, torch_dtype="auto")
    self.assertEqual(t5.dtype, fp32)

    # Explicit fp16 request wins even though the stored weights are fp32.
    t5 = T5ForConditionalGeneration.from_pretrained(save_dir, torch_dtype=fp16)
    self.assertEqual(t5.dtype, fp16)

    # Save a half-precision checkpoint and reload it with auto-detection.
    t5 = t5.half()
    t5.save_pretrained(save_dir)
    t5 = T5ForConditionalGeneration.from_pretrained(save_dir, torch_dtype="auto")
    # also verifies that `config.torch_dtype` was written out on save
    self.assertEqual(t5.config.torch_dtype, "float16")
    self.assertEqual(t5.dtype, fp16)

    # Half-precision checkpoint reloaded with an explicit fp16 request.
    t5 = T5ForConditionalGeneration.from_pretrained(save_dir, torch_dtype=fp16)
    self.assertEqual(t5.dtype, fp16)
@require_torch
@is_staging_test