Add video modality for InstructBLIP (#30182)
* squash in single commit * add docs * dummy obj * more changes in diff converter * tiny fix * make docs happy * skip test * repo consistency tests * update docstring * style * fix tests * change diff imports * [run-slow] instructblipvideo * [run-slow] instructblipvideo * fix tests and remove logit check * [run-slow] instructblipvideo
This commit is contained in:
committed by
GitHub
parent
a958c4a801
commit
fc689d75a0
@@ -90,6 +90,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"RecurrentGemmaModel", # Building part of bigger (tested) model.
|
||||
"FuyuForCausalLM", # Not tested for now
|
||||
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
|
||||
"InstructBlipVideoQFormerModel", # Building part of bigger (tested) model.
|
||||
"UMT5EncoderModel", # Building part of bigger (tested) model.
|
||||
"Blip2QFormerModel", # Building part of bigger (tested) model.
|
||||
"ErnieMForInformationExtraction",
|
||||
@@ -245,6 +246,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"GPTSw3DoubleHeadsModel",
|
||||
"InstructBlipVisionModel",
|
||||
"InstructBlipQFormerModel",
|
||||
"InstructBlipVideoVisionModel",
|
||||
"InstructBlipVideoQFormerModel",
|
||||
"LayoutLMForQuestionAnswering",
|
||||
"LukeForMaskedLM",
|
||||
"LukeForEntityClassification",
|
||||
|
||||
@@ -173,7 +173,7 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
- LLaMa -> MyNewModel and MyNewModel -> Llama
|
||||
"""
|
||||
|
||||
def __init__(self, old_name, new_name):
|
||||
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
|
||||
super().__init__()
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
@@ -183,6 +183,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
old_name.upper(): new_name.upper(),
|
||||
"".join(x.title() for x in old_name.split("_")): self.default_name,
|
||||
}
|
||||
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
|
||||
self.patterns[given_old_name] = given_new_name
|
||||
|
||||
def preserve_case_replace(self, text):
|
||||
# Create a regex pattern to match all variations
|
||||
@@ -201,9 +203,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
return updated_node.with_changes(value=update)
|
||||
|
||||
|
||||
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
|
||||
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
|
||||
"""Helper function to rename and then parse a source file using the ClassFinder"""
|
||||
transformer = ReplaceNameTransformer(old_id, new_id)
|
||||
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
|
||||
new_module = module.visit(transformer)
|
||||
|
||||
wrapper = MetadataWrapper(new_module)
|
||||
@@ -356,11 +358,13 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||
|
||||
def __init__(self, python_module, new_name):
|
||||
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
|
||||
super().__init__()
|
||||
self.model_name = (
|
||||
new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
|
||||
)
|
||||
self.given_old_name = given_old_name
|
||||
self.given_new_name = given_new_name
|
||||
# fmt: off
|
||||
self.python_module = python_module # we store the original module to use `code_for_node`
|
||||
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
|
||||
@@ -426,6 +430,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
"insert_idx": self.global_scope_index,
|
||||
"node": updated_node,
|
||||
}
|
||||
self.config_body = [updated_node]
|
||||
return updated_node
|
||||
|
||||
def leave_ClassDef(self, original_node, updated_node):
|
||||
@@ -457,13 +462,18 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
|
||||
)
|
||||
|
||||
if super_file_name not in self.visited_module: # only extract classes once
|
||||
visited_module = self.visited_module
|
||||
if super_file_name not in visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(
|
||||
self.transformers_imports[super_file_name], model_name, self.model_name
|
||||
self.transformers_imports[super_file_name],
|
||||
model_name,
|
||||
self.model_name,
|
||||
self.given_old_name,
|
||||
self.given_new_name,
|
||||
)
|
||||
self.visited_module[super_file_name] = class_finder
|
||||
visited_module[super_file_name] = class_finder
|
||||
else: # we are re-using the previously parsed data
|
||||
class_finder = self.visited_module[super_file_name]
|
||||
class_finder = visited_module[super_file_name]
|
||||
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
@@ -474,7 +484,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
start_insert_idx = self.global_scope_index
|
||||
for dependency, _ in list_dependencies:
|
||||
node = class_finder.global_nodes.get(dependency, None)
|
||||
if node is not None:
|
||||
if node is not None and "Config" not in class_name:
|
||||
if dependency not in self.new_body:
|
||||
start_insert_idx -= 1
|
||||
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
||||
@@ -485,7 +495,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
if len(list_dependencies) > 0:
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
|
||||
if "Config" in class_name:
|
||||
self.config_body = [updated_node]
|
||||
self.config_body += [updated_node]
|
||||
else:
|
||||
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
return updated_node
|
||||
@@ -503,10 +513,24 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
def leave_Module(self, original_node: cst.Assign, node):
|
||||
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
||||
dependency_imports = {}
|
||||
config_imports = []
|
||||
for visiter in self.visited_module.values():
|
||||
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
|
||||
|
||||
# manually clean up if it's importing a config from configuration file (ruff doesn't do that)
|
||||
config_imports = []
|
||||
for i in list(dependency_imports.values()):
|
||||
if (
|
||||
hasattr(i.body[0], "module")
|
||||
and isinstance(i.body[0].module, cst.Name)
|
||||
and f"configuration_{self.model_name}" in i.body[0].module.value
|
||||
):
|
||||
pass
|
||||
else:
|
||||
config_imports.append(i)
|
||||
|
||||
if hasattr(self, "config_body"):
|
||||
self.config_body = list(imports.values()) + self.config_body
|
||||
self.config_body = list(imports.values()) + config_imports + self.config_body
|
||||
dependency_imports.update(imports)
|
||||
new_body = list(dependency_imports.values())
|
||||
if len(self.new_body.keys()) > 0:
|
||||
@@ -516,7 +540,7 @@ class DiffConverterTransformer(CSTTransformer):
|
||||
return node.with_changes(body=[*new_body])
|
||||
|
||||
|
||||
def convert_file(diff_file, cst_transformers=None):
|
||||
def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
|
||||
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
|
||||
# Parse the Python file
|
||||
with open(diff_file, "r") as file:
|
||||
@@ -524,7 +548,7 @@ def convert_file(diff_file, cst_transformers=None):
|
||||
module = cst.parse_module(code)
|
||||
wrapper = MetadataWrapper(module)
|
||||
if cst_transformers is None:
|
||||
cst_transformers = DiffConverterTransformer(module, model_name)
|
||||
cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
|
||||
new_mod = wrapper.visit(cst_transformers)
|
||||
ruffed_code = run_ruff(new_mod.code, True)
|
||||
formatted_code = run_ruff(ruffed_code, False)
|
||||
@@ -551,10 +575,20 @@ if __name__ == "__main__":
|
||||
nargs="+",
|
||||
help="A list of `diff_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_model_name",
|
||||
required=False,
|
||||
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--new_model_name",
|
||||
required=False,
|
||||
help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.files_to_parse == ["all"]:
|
||||
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
|
||||
for file_name in args.files_to_parse:
|
||||
print(f"Converting {file_name} to a single model single file format")
|
||||
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
||||
converter = convert_file(file_name)
|
||||
converter = convert_file(file_name, args.old_model_name, args.new_model_name)
|
||||
|
||||
Reference in New Issue
Block a user