def _floating_or_last_dtype(tensors):
    """
    Returns the dtype of the first floating point tensor in `tensors`; if there is none, returns the dtype of the last
    tensor seen, or `None` when `tensors` yields nothing at all.
    """
    last_dtype = None
    for t in tensors:
        if t.is_floating_point():
            return t.dtype
        last_dtype = t.dtype
    return last_dtype


def get_parameter_dtype(parameter: "Union[nn.Module, GenerationMixin, ModuleUtilsMixin]"):
    """
    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.

    When `parameter.parameters()` yields nothing, the plain tensor attributes found in the modules' `__dict__` are
    inspected instead, following the same "first floating, else last" rule.

    Raises:
        ValueError: if neither parameters nor tensor attributes could be found.
    """
    # Exhausting `parameters()` in a `for` loop never raises `StopIteration`, so the "no parameters" case has to be
    # detected explicitly (here: the helper returning `None`) rather than with a `try`/`except`.
    dtype = _floating_or_last_dtype(parameter.parameters())
    if dtype is not None:
        return dtype

    # For nn.DataParallel compatibility in PyTorch 1.5

    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    named_tensors = parameter._named_members(get_members_fn=find_tensor_attributes)
    # fallback to any dtype the model has even if not floating
    dtype = _floating_or_last_dtype(tensor for _, tensor in named_tensors)
    if dtype is None:
        raise ValueError("couldn't find any parameters or tensor attributes to infer the dtype from")
    return dtype


def get_state_dict_float_dtype(state_dict):
    """
    Returns the first found floating dtype in `state_dict`.

    Raises:
        ValueError: if `state_dict` holds no floating point tensor (this includes an empty `state_dict`).
    """
    for t in state_dict.values():
        if t.is_floating_point():
            return t.dtype

    raise ValueError("couldn't find any floating point dtypes in state_dict")


def get_state_dict_dtype(state_dict):
    """
    Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the last dtype.

    Raises:
        ValueError: if `state_dict` is empty, as there is no tensor to take a dtype from.
    """
    dtype = _floating_or_last_dtype(state_dict.values())
    if dtype is None:
        raise ValueError("couldn't find any tensors in state_dict")
    return dtype
+134,7 @@ def _config_zero_init(config): TINY_T5 = "patrickvonplaten/t5-tiny-random" +TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification" @require_torch @@ -2557,6 +2558,10 @@ class ModelUtilsTest(TestCasePlus): model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16) self.assertEqual(model.dtype, torch.float16) + # test model whose first param is not of a floating type, but int + model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float32) + def test_no_super_init_config_and_model(self): config = NoSuperInitConfig(attribute=32) model = NoSuperInitModel(config)