Include image processor in add-new-model-like (#20439)
This commit is contained in:
@@ -62,6 +62,9 @@ class ModelPatterns:
|
||||
The tokenizer class associated with this model. Will default to `"{model_camel_cased}Config"`.
|
||||
tokenizer_class (`str`, *optional*):
|
||||
The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
|
||||
image_processor_class (`str`, *optional*):
|
||||
The image processor class associated with this model (leave to `None` for models that don't use an image
|
||||
processor).
|
||||
feature_extractor_class (`str`, *optional*):
|
||||
The feature extractor class associated with this model (leave to `None` for models that don't use a feature
|
||||
extractor).
|
||||
@@ -77,6 +80,7 @@ class ModelPatterns:
|
||||
model_upper_cased: Optional[str] = None
|
||||
config_class: Optional[str] = None
|
||||
tokenizer_class: Optional[str] = None
|
||||
image_processor_class: Optional[str] = None
|
||||
feature_extractor_class: Optional[str] = None
|
||||
processor_class: Optional[str] = None
|
||||
|
||||
@@ -101,6 +105,7 @@ class ModelPatterns:
|
||||
ATTRIBUTE_TO_PLACEHOLDER = {
|
||||
"config_class": "[CONFIG_CLASS]",
|
||||
"tokenizer_class": "[TOKENIZER_CLASS]",
|
||||
"image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
|
||||
"feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
|
||||
"processor_class": "[PROCESSOR_CLASS]",
|
||||
"checkpoint": "[CHECKPOINT]",
|
||||
@@ -283,7 +288,7 @@ def replace_model_patterns(
|
||||
# contains the camel-cased named, but will be treated before.
|
||||
attributes_to_check = ["config_class"]
|
||||
# Add relevant preprocessing classes
|
||||
for attr in ["tokenizer_class", "feature_extractor_class", "processor_class"]:
|
||||
for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
|
||||
if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
|
||||
attributes_to_check.append(attr)
|
||||
|
||||
@@ -553,6 +558,7 @@ def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) ->
|
||||
f"test_modeling_tf_{module_name}.py",
|
||||
f"test_modeling_flax_{module_name}.py",
|
||||
f"test_tokenization_{module_name}.py",
|
||||
f"test_image_processing_{module_name}.py",
|
||||
f"test_feature_extraction_{module_name}.py",
|
||||
f"test_processor_{module_name}.py",
|
||||
]
|
||||
@@ -687,6 +693,7 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
|
||||
tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
|
||||
else:
|
||||
tokenizer_class = None
|
||||
image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
|
||||
feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
|
||||
processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
|
||||
|
||||
@@ -731,6 +738,7 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
|
||||
model_upper_cased=model_upper_cased,
|
||||
config_class=config_class,
|
||||
tokenizer_class=tokenizer_class,
|
||||
image_processor_class=image_processor_class,
|
||||
feature_extractor_class=feature_extractor_class,
|
||||
processor_class=processor_class,
|
||||
)
|
||||
@@ -748,14 +756,15 @@ def clean_frameworks_in_init(
|
||||
):
|
||||
"""
|
||||
Removes all the import lines that don't belong to a given list of frameworks or concern tokenizers/feature
|
||||
extractors/processors in an init.
|
||||
extractors/image processors/processors in an init.
|
||||
|
||||
Args:
|
||||
init_file (`str` or `os.PathLike`): The path to the init to treat.
|
||||
frameworks (`List[str]`, *optional*):
|
||||
If passed, this will remove all imports that are subject to a framework not in frameworks
|
||||
keep_processing (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
|
||||
Whether or not to keep the preprocessing (tokenizer, feature extractor, image processor, processor) imports
|
||||
in the init.
|
||||
"""
|
||||
if frameworks is None:
|
||||
frameworks = get_default_frameworks()
|
||||
@@ -808,8 +817,9 @@ def clean_frameworks_in_init(
|
||||
idx += 1
|
||||
# Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it.
|
||||
elif keep_processing or (
|
||||
re.search('^\s*"(tokenization|processing|feature_extraction)', lines[idx]) is None
|
||||
and re.search("^\s*from .(tokenization|processing|feature_extraction)", lines[idx]) is None
|
||||
re.search('^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
|
||||
and re.search("^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
|
||||
is None
|
||||
):
|
||||
new_lines.append(lines[idx])
|
||||
idx += 1
|
||||
@@ -885,6 +895,7 @@ def add_model_to_main_init(
|
||||
if not with_processing:
|
||||
processing_classes = [
|
||||
old_model_patterns.tokenizer_class,
|
||||
old_model_patterns.image_processor_class,
|
||||
old_model_patterns.feature_extractor_class,
|
||||
old_model_patterns.processor_class,
|
||||
]
|
||||
@@ -962,6 +973,7 @@ AUTO_CLASSES_PATTERNS = {
|
||||
' ("{model_type}", "{pretrained_archive_map}"),',
|
||||
],
|
||||
"feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
|
||||
"image_processing_auto.py": [' ("{model_type}", "{image_processor_class}"),'],
|
||||
"modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
|
||||
"modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
|
||||
"modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
|
||||
@@ -995,6 +1007,14 @@ def add_model_to_auto_classes(
|
||||
)
|
||||
elif "{config_class}" in pattern:
|
||||
new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
|
||||
elif "{image_processor_class}" in pattern:
|
||||
if (
|
||||
old_model_patterns.image_processor_class is not None
|
||||
and new_model_patterns.image_processor_class is not None
|
||||
):
|
||||
new_patterns.append(
|
||||
pattern.replace("{image_processor_class}", old_model_patterns.image_processor_class)
|
||||
)
|
||||
elif "{feature_extractor_class}" in pattern:
|
||||
if (
|
||||
old_model_patterns.feature_extractor_class is not None
|
||||
@@ -1121,6 +1141,10 @@ def duplicate_doc_file(
|
||||
# We only add the tokenizer if necessary
|
||||
if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
|
||||
new_blocks.append(new_block)
|
||||
elif "ImageProcessor" in block_class:
|
||||
# We only add the image processor if necessary
|
||||
if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
|
||||
new_blocks.append(new_block)
|
||||
elif "FeatureExtractor" in block_class:
|
||||
# We only add the feature extractor if necessary
|
||||
if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
|
||||
@@ -1182,7 +1206,7 @@ def create_new_model_like(
|
||||
)
|
||||
|
||||
keep_old_processing = True
|
||||
for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
|
||||
for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
|
||||
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
|
||||
keep_old_processing = False
|
||||
|
||||
@@ -1198,7 +1222,10 @@ def create_new_model_like(
|
||||
files_to_adapt = [
|
||||
f
|
||||
for f in files_to_adapt
|
||||
if "tokenization" not in str(f) and "processing" not in str(f) and "feature_extraction" not in str(f)
|
||||
if "tokenization" not in str(f)
|
||||
and "processing" not in str(f)
|
||||
and "feature_extraction" not in str(f)
|
||||
and "image_processing" not in str(f)
|
||||
]
|
||||
|
||||
os.makedirs(module_folder, exist_ok=True)
|
||||
@@ -1236,7 +1263,10 @@ def create_new_model_like(
|
||||
files_to_adapt = [
|
||||
f
|
||||
for f in files_to_adapt
|
||||
if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
|
||||
if "tokenization" not in str(f)
|
||||
and "processor" not in str(f)
|
||||
and "feature_extraction" not in str(f)
|
||||
and "image_processing" not in str(f)
|
||||
]
|
||||
|
||||
def disable_fx_test(filename: Path) -> bool:
|
||||
@@ -1458,6 +1488,7 @@ def get_user_input():
|
||||
|
||||
old_model_info = retrieve_info_for_model(old_model_type)
|
||||
old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
|
||||
old_image_processor_class = old_model_info["model_patterns"].image_processor_class
|
||||
old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
|
||||
old_processor_class = old_model_info["model_patterns"].processor_class
|
||||
old_frameworks = old_model_info["frameworks"]
|
||||
@@ -1497,7 +1528,9 @@ def get_user_input():
|
||||
)
|
||||
|
||||
old_processing_classes = [
|
||||
c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None
|
||||
c
|
||||
for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
|
||||
if c is not None
|
||||
]
|
||||
old_processing_classes = ", ".join(old_processing_classes)
|
||||
keep_processing = get_user_field(
|
||||
@@ -1506,6 +1539,7 @@ def get_user_input():
|
||||
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
||||
)
|
||||
if keep_processing:
|
||||
image_processor_class = old_image_processor_class
|
||||
feature_extractor_class = old_feature_extractor_class
|
||||
processor_class = old_processor_class
|
||||
tokenizer_class = old_tokenizer_class
|
||||
@@ -1517,6 +1551,11 @@ def get_user_input():
|
||||
)
|
||||
else:
|
||||
tokenizer_class = None
|
||||
if old_image_processor_class is not None:
|
||||
image_processor_class = get_user_field(
|
||||
"What will be the name of the image processor class for this model? ",
|
||||
default_value=f"{model_camel_cased}ImageProcessor",
|
||||
)
|
||||
if old_feature_extractor_class is not None:
|
||||
feature_extractor_class = get_user_field(
|
||||
"What will be the name of the feature extractor class for this model? ",
|
||||
@@ -1541,6 +1580,7 @@ def get_user_input():
|
||||
model_upper_cased=model_upper_cased,
|
||||
config_class=config_class,
|
||||
tokenizer_class=tokenizer_class,
|
||||
image_processor_class=image_processor_class,
|
||||
feature_extractor_class=feature_extractor_class,
|
||||
processor_class=processor_class,
|
||||
)
|
||||
|
||||
@@ -44,12 +44,14 @@ BERT_MODEL_FILES = {
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/tokenization_bert.py",
|
||||
"src/transformers/models/bert/tokenization_bert_fast.py",
|
||||
"src/transformers/models/bert/tokenization_bert_tf.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
"src/transformers/models/bert/modeling_flax_bert.py",
|
||||
"src/transformers/models/bert/modeling_tf_bert.py",
|
||||
"src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py",
|
||||
"src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py",
|
||||
"src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py",
|
||||
"src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py",
|
||||
}
|
||||
|
||||
VIT_MODEL_FILES = {
|
||||
@@ -58,6 +60,7 @@ VIT_MODEL_FILES = {
|
||||
"src/transformers/models/vit/convert_dino_to_pytorch.py",
|
||||
"src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
|
||||
"src/transformers/models/vit/feature_extraction_vit.py",
|
||||
"src/transformers/models/vit/image_processing_vit.py",
|
||||
"src/transformers/models/vit/modeling_vit.py",
|
||||
"src/transformers/models/vit/modeling_tf_vit.py",
|
||||
"src/transformers/models/vit/modeling_flax_vit.py",
|
||||
@@ -89,7 +92,8 @@ class TestAddNewModelLike(unittest.TestCase):
|
||||
|
||||
def check_result(self, file_name, expected_result):
|
||||
with open(file_name, "r", encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), expected_result)
|
||||
result = f.read()
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_re_class_func(self):
|
||||
self.assertEqual(_re_class_func.search("def my_function(x, y):").groups()[0], "my_function")
|
||||
@@ -439,7 +443,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
self.check_result(dest_file_name, bert_expected)
|
||||
|
||||
def test_filter_framework_files(self):
|
||||
files = ["modeling_tf_bert.py", "modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
|
||||
files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
|
||||
self.assertEqual(filter_framework_files(files), files)
|
||||
self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files))
|
||||
|
||||
@@ -467,7 +471,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
bert_files = get_model_files("bert")
|
||||
|
||||
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
|
||||
self.assertEqual(model_files, BERT_MODEL_FILES)
|
||||
@@ -476,17 +480,17 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/test_tokenization_bert.py",
|
||||
"tests/test_modeling_bert.py",
|
||||
"tests/test_modeling_tf_bert.py",
|
||||
"tests/test_modeling_flax_bert.py",
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_tf_bert.py",
|
||||
"tests/models/bert/test_modeling_flax_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
# VIT
|
||||
vit_files = get_model_files("vit")
|
||||
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
|
||||
self.assertEqual(model_files, VIT_MODEL_FILES)
|
||||
@@ -495,17 +499,17 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
|
||||
vit_test_files = {
|
||||
"tests/test_feature_extraction_vit.py",
|
||||
"tests/test_modeling_vit.py",
|
||||
"tests/test_modeling_tf_vit.py",
|
||||
"tests/test_modeling_flax_vit.py",
|
||||
"tests/models/vit/test_feature_extraction_vit.py",
|
||||
"tests/models/vit/test_modeling_vit.py",
|
||||
"tests/models/vit/test_modeling_tf_vit.py",
|
||||
"tests/models/vit/test_modeling_flax_vit.py",
|
||||
}
|
||||
self.assertEqual(test_files, vit_test_files)
|
||||
|
||||
# Wav2Vec2
|
||||
wav2vec2_files = get_model_files("wav2vec2")
|
||||
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
|
||||
self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
|
||||
@@ -514,12 +518,12 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
|
||||
wav2vec2_test_files = {
|
||||
"tests/test_feature_extraction_wav2vec2.py",
|
||||
"tests/test_modeling_wav2vec2.py",
|
||||
"tests/test_modeling_tf_wav2vec2.py",
|
||||
"tests/test_modeling_flax_wav2vec2.py",
|
||||
"tests/test_processor_wav2vec2.py",
|
||||
"tests/test_tokenization_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_processor_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
|
||||
}
|
||||
self.assertEqual(test_files, wav2vec2_test_files)
|
||||
|
||||
@@ -528,7 +532,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
bert_files = get_model_files("bert", frameworks=["pt"])
|
||||
|
||||
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
|
||||
bert_model_files = BERT_MODEL_FILES - {
|
||||
@@ -541,15 +545,15 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/test_tokenization_bert.py",
|
||||
"tests/test_modeling_bert.py",
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
# VIT
|
||||
vit_files = get_model_files("vit", frameworks=["pt"])
|
||||
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
|
||||
vit_model_files = VIT_MODEL_FILES - {
|
||||
@@ -562,15 +566,15 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
|
||||
vit_test_files = {
|
||||
"tests/test_feature_extraction_vit.py",
|
||||
"tests/test_modeling_vit.py",
|
||||
"tests/models/vit/test_feature_extraction_vit.py",
|
||||
"tests/models/vit/test_modeling_vit.py",
|
||||
}
|
||||
self.assertEqual(test_files, vit_test_files)
|
||||
|
||||
# Wav2Vec2
|
||||
wav2vec2_files = get_model_files("wav2vec2", frameworks=["pt"])
|
||||
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
|
||||
wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {
|
||||
@@ -583,10 +587,10 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
|
||||
wav2vec2_test_files = {
|
||||
"tests/test_feature_extraction_wav2vec2.py",
|
||||
"tests/test_modeling_wav2vec2.py",
|
||||
"tests/test_processor_wav2vec2.py",
|
||||
"tests/test_tokenization_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_processor_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
|
||||
}
|
||||
self.assertEqual(test_files, wav2vec2_test_files)
|
||||
|
||||
@@ -595,7 +599,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
bert_files = get_model_files("bert", frameworks=["tf", "flax"])
|
||||
|
||||
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
|
||||
bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"}
|
||||
@@ -605,16 +609,16 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/test_tokenization_bert.py",
|
||||
"tests/test_modeling_tf_bert.py",
|
||||
"tests/test_modeling_flax_bert.py",
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_tf_bert.py",
|
||||
"tests/models/bert/test_modeling_flax_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
# VIT
|
||||
vit_files = get_model_files("vit", frameworks=["tf", "flax"])
|
||||
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
|
||||
vit_model_files = VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"}
|
||||
@@ -624,16 +628,16 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
|
||||
vit_test_files = {
|
||||
"tests/test_feature_extraction_vit.py",
|
||||
"tests/test_modeling_tf_vit.py",
|
||||
"tests/test_modeling_flax_vit.py",
|
||||
"tests/models/vit/test_feature_extraction_vit.py",
|
||||
"tests/models/vit/test_modeling_tf_vit.py",
|
||||
"tests/models/vit/test_modeling_flax_vit.py",
|
||||
}
|
||||
self.assertEqual(test_files, vit_test_files)
|
||||
|
||||
# Wav2Vec2
|
||||
wav2vec2_files = get_model_files("wav2vec2", frameworks=["tf", "flax"])
|
||||
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
|
||||
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
|
||||
wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"}
|
||||
@@ -643,11 +647,11 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
|
||||
wav2vec2_test_files = {
|
||||
"tests/test_feature_extraction_wav2vec2.py",
|
||||
"tests/test_modeling_tf_wav2vec2.py",
|
||||
"tests/test_modeling_flax_wav2vec2.py",
|
||||
"tests/test_processor_wav2vec2.py",
|
||||
"tests/test_tokenization_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_processor_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
|
||||
}
|
||||
self.assertEqual(test_files, wav2vec2_test_files)
|
||||
|
||||
@@ -688,7 +692,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
expected_model_classes = {
|
||||
"pt": set(bert_classes),
|
||||
"tf": {f"TF{m}" for m in bert_classes},
|
||||
"flax": {f"Flax{m}" for m in bert_classes[:-1]},
|
||||
"flax": {f"Flax{m}" for m in bert_classes[:-1] + ["BertForCausalLM"]},
|
||||
}
|
||||
|
||||
self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"})
|
||||
@@ -701,15 +705,15 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/test_tokenization_bert.py",
|
||||
"tests/test_modeling_bert.py",
|
||||
"tests/test_modeling_tf_bert.py",
|
||||
"tests/test_modeling_flax_bert.py",
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_tf_bert.py",
|
||||
"tests/models/bert/test_modeling_flax_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
|
||||
|
||||
self.assertEqual(all_bert_files["module_name"], "bert")
|
||||
|
||||
@@ -751,14 +755,14 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/test_tokenization_bert.py",
|
||||
"tests/test_modeling_bert.py",
|
||||
"tests/test_modeling_tf_bert.py",
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_tf_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
|
||||
|
||||
self.assertEqual(all_bert_files["module_name"], "bert")
|
||||
|
||||
@@ -777,8 +781,9 @@ NEW_BERT_CONSTANT = "value"
|
||||
def test_retrieve_info_for_model_with_vit(self):
|
||||
vit_info = retrieve_info_for_model("vit")
|
||||
vit_classes = ["ViTForImageClassification", "ViTModel"]
|
||||
pt_only_classes = ["ViTForMaskedImageModeling"]
|
||||
expected_model_classes = {
|
||||
"pt": set(vit_classes),
|
||||
"pt": set(vit_classes + pt_only_classes),
|
||||
"tf": {f"TF{m}" for m in vit_classes},
|
||||
"flax": {f"Flax{m}" for m in vit_classes},
|
||||
}
|
||||
@@ -793,27 +798,28 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["test_files"]}
|
||||
vit_test_files = {
|
||||
"tests/test_feature_extraction_vit.py",
|
||||
"tests/test_modeling_vit.py",
|
||||
"tests/test_modeling_tf_vit.py",
|
||||
"tests/test_modeling_flax_vit.py",
|
||||
"tests/models/vit/test_feature_extraction_vit.py",
|
||||
"tests/models/vit/test_modeling_vit.py",
|
||||
"tests/models/vit/test_modeling_tf_vit.py",
|
||||
"tests/models/vit/test_modeling_flax_vit.py",
|
||||
}
|
||||
self.assertEqual(test_files, vit_test_files)
|
||||
|
||||
doc_file = str(Path(all_vit_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
|
||||
|
||||
self.assertEqual(all_vit_files["module_name"], "vit")
|
||||
|
||||
vit_model_patterns = vit_info["model_patterns"]
|
||||
self.assertEqual(vit_model_patterns.model_name, "ViT")
|
||||
self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224")
|
||||
self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224-in21k")
|
||||
self.assertEqual(vit_model_patterns.model_type, "vit")
|
||||
self.assertEqual(vit_model_patterns.model_lower_cased, "vit")
|
||||
self.assertEqual(vit_model_patterns.model_camel_cased, "ViT")
|
||||
self.assertEqual(vit_model_patterns.model_upper_cased, "VIT")
|
||||
self.assertEqual(vit_model_patterns.config_class, "ViTConfig")
|
||||
self.assertEqual(vit_model_patterns.feature_extractor_class, "ViTFeatureExtractor")
|
||||
self.assertEqual(vit_model_patterns.image_processor_class, "ViTImageProcessor")
|
||||
self.assertIsNone(vit_model_patterns.tokenizer_class)
|
||||
self.assertIsNone(vit_model_patterns.processor_class)
|
||||
|
||||
@@ -844,17 +850,17 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["test_files"]}
|
||||
wav2vec2_test_files = {
|
||||
"tests/test_feature_extraction_wav2vec2.py",
|
||||
"tests/test_modeling_wav2vec2.py",
|
||||
"tests/test_modeling_tf_wav2vec2.py",
|
||||
"tests/test_modeling_flax_wav2vec2.py",
|
||||
"tests/test_processor_wav2vec2.py",
|
||||
"tests/test_tokenization_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_processor_wav2vec2.py",
|
||||
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
|
||||
}
|
||||
self.assertEqual(test_files, wav2vec2_test_files)
|
||||
|
||||
doc_file = str(Path(all_wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
|
||||
|
||||
self.assertEqual(all_wav2vec2_files["module_name"], "wav2vec2")
|
||||
|
||||
@@ -881,32 +887,72 @@ _import_structure = {
|
||||
"tokenization_gpt2": ["GPT2Tokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_gpt2"] = ["GPT2Model"]
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_gpt2_fast import GPT2TokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_gpt2 import GPT2Model
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_gpt2 import TFGPT2Model
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_gpt2 import FlaxGPT2Model
|
||||
|
||||
else:
|
||||
@@ -924,25 +970,55 @@ _import_structure = {
|
||||
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_gpt2"] = ["GPT2Model"]
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_gpt2 import GPT2Model
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_gpt2 import TFGPT2Model
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_gpt2 import FlaxGPT2Model
|
||||
|
||||
else:
|
||||
@@ -961,20 +1037,40 @@ _import_structure = {
|
||||
"tokenization_gpt2": ["GPT2Tokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_gpt2"] = ["GPT2Model"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_gpt2_fast import GPT2TokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_gpt2 import GPT2Model
|
||||
|
||||
else:
|
||||
@@ -992,13 +1088,23 @@ _import_structure = {
|
||||
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_gpt2"] = ["GPT2Model"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_gpt2 import GPT2Model
|
||||
|
||||
else:
|
||||
@@ -1032,32 +1138,72 @@ _import_structure = {
|
||||
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
}
|
||||
|
||||
if is_vision_available():
|
||||
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vit"] = ["ViTModel"]
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_vit"] = ["TFViTModel"]
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
if is_vision_available():
|
||||
from .feature_extraction_vit import ViTFeatureExtractor
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_vit import ViTImageProcessor
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vit import ViTModel
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_vit import ViTModel
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_vit import TFViTModel
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_vit import ViTModel
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_vit import FlaxViTModel
|
||||
|
||||
else:
|
||||
import sys
|
||||
@@ -1074,26 +1220,56 @@ _import_structure = {
|
||||
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vit"] = ["ViTModel"]
|
||||
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_vit"] = ["TFViTModel"]
|
||||
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vit import ViTModel
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_vit import ViTModel
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_vit import TFViTModel
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_vit import ViTModel
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_flax_vit import FlaxViTModel
|
||||
|
||||
else:
|
||||
import sys
|
||||
@@ -1110,19 +1286,39 @@ _import_structure = {
|
||||
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
}
|
||||
|
||||
if is_vision_available():
|
||||
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vit"] = ["ViTModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
if is_vision_available():
|
||||
from .feature_extraction_vit import ViTFeatureExtractor
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_vit import ViTImageProcessor
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vit import ViTModel
|
||||
|
||||
else:
|
||||
@@ -1140,13 +1336,23 @@ _import_structure = {
|
||||
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vit"] = ["ViTModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vit import ViTModel
|
||||
|
||||
else:
|
||||
@@ -1218,7 +1424,7 @@ Overview of the model.
|
||||
|
||||
## Overview
|
||||
|
||||
The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>(<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
@@ -1229,7 +1435,7 @@ Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](<https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user