diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 17cc7f9579..aad113d454 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -435,19 +435,24 @@ class _BaseAutoModelClass: ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} if not isinstance(config, PretrainedConfig): - kwargs_copy = copy.deepcopy(kwargs) + kwargs_orig = copy.deepcopy(kwargs) # ensure not to pollute the config object with torch_dtype="auto" - since it's # meaningless in the context of the config object - torch.dtype values are acceptable - if kwargs_copy.get("torch_dtype", None) == "auto": - _ = kwargs_copy.pop("torch_dtype") + if kwargs.get("torch_dtype", None) == "auto": + _ = kwargs.pop("torch_dtype") config, kwargs = AutoConfig.from_pretrained( pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_kwargs, - **kwargs_copy, + **kwargs, ) + + # if torch_dtype=auto was passed here, ensure to pass it on + if kwargs_orig.get("torch_dtype", None) == "auto": + kwargs["torch_dtype"] = "auto" + if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: if not trust_remote_code: raise ValueError( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 43e06e9067..e6e7968ae5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2920,6 +2920,10 @@ class ModelUtilsTest(TestCasePlus): model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") self.assertEqual(model.dtype, torch.float16) + # 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration + model = AutoModel.from_pretrained(model_path, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float16) + # test fp16 save_pretrained, loaded with the explicit fp16 model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16) self.assertEqual(model.dtype, torch.float16)