Better typing for model.config (#39132)
* Apply to all models config annotation * Update modular to preserve order * Apply modular * fix define docstring * fix dinov2 consistency (docs<->modular) * fix InstructBlipVideoForConditionalGeneration docs<->modular consistency * fixup * remove duplicate code * Delete config_class attribute from the modeling code * Add config_class attribute in base model * Update init sub class * Deprecated models update * Update new models * Fix remote code BC issue * fixup * fixing more corner cases * fix new models * add test * modular docs update * fix comment a bit * fix for py3.9
This commit is contained in:
committed by
GitHub
parent
4b258454a7
commit
cc24b0378e
@@ -33,7 +33,7 @@ from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, Optional, TypeVar, Union, get_type_hints
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import torch
|
||||
@@ -2060,6 +2060,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"""
|
||||
return "pt"
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
# For BC we keep the original `config_class` definition in case
|
||||
# there is a `config_class` attribute (e.g. remote code models),
|
||||
# otherwise we derive it from the annotated `config` attribute.
|
||||
|
||||
# defined in this particular subclass
|
||||
child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None)
|
||||
child_attribute = cls.__dict__.get("config_class", None)
|
||||
|
||||
# defined in the class (this subclass or any parent class)
|
||||
full_annotation = get_type_hints(cls).get("config", None)
|
||||
full_attribute = cls.config_class
|
||||
|
||||
# priority (child class_config -> child annotation -> global class_config -> global annotation)
|
||||
if child_attribute is not None:
|
||||
cls.config_class = child_attribute
|
||||
elif child_annotation is not None:
|
||||
cls.config_class = child_annotation
|
||||
elif full_attribute is not None:
|
||||
cls.config_class = full_attribute
|
||||
elif full_annotation is not None:
|
||||
cls.config_class = full_annotation
|
||||
|
||||
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
|
||||
Reference in New Issue
Block a user