Add video modality for InstructBLIP (#30182)

* squash in single commit

* add docs

* dummy obj

* more changes in diff converter

* tiny fix

* make docs happy

* skip test

* repo consistency tests

* update docstring

* style

* fix tests

* change diff imports

* [run-slow] instructblipvideo

* [run-slow] instructblipvideo

* fix tests and remove logit check

* [run-slow] instructblipvideo
This commit is contained in:
Raushan Turganbay
2024-06-25 15:45:39 +05:00
committed by GitHub
parent a958c4a801
commit fc689d75a0
28 changed files with 4358 additions and 17 deletions

View File

@@ -90,6 +90,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"RecurrentGemmaModel", # Building part of bigger (tested) model.
"FuyuForCausalLM", # Not tested for now
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"InstructBlipVideoQFormerModel", # Building part of bigger (tested) model.
"UMT5EncoderModel", # Building part of bigger (tested) model.
"Blip2QFormerModel", # Building part of bigger (tested) model.
"ErnieMForInformationExtraction",
@@ -245,6 +246,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"GPTSw3DoubleHeadsModel",
"InstructBlipVisionModel",
"InstructBlipQFormerModel",
"InstructBlipVideoVisionModel",
"InstructBlipVideoQFormerModel",
"LayoutLMForQuestionAnswering",
"LukeForMaskedLM",
"LukeForEntityClassification",

View File

@@ -173,7 +173,7 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
- LLaMa -> MyNewModel and MyNewModel -> Llama
"""
def __init__(self, old_name, new_name):
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
super().__init__()
self.old_name = old_name
self.new_name = new_name
@@ -183,6 +183,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name
def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
@@ -201,9 +203,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return updated_node.with_changes(value=update)
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
"""Helper function to rename and then parse a source file using the ClassFinder"""
transformer = ReplaceNameTransformer(old_id, new_id)
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
new_module = module.visit(transformer)
wrapper = MetadataWrapper(new_module)
@@ -356,11 +358,13 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
class DiffConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module, new_name):
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
super().__init__()
self.model_name = (
new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
)
self.given_old_name = given_old_name
self.given_new_name = given_new_name
# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
@@ -426,6 +430,7 @@ class DiffConverterTransformer(CSTTransformer):
"insert_idx": self.global_scope_index,
"node": updated_node,
}
self.config_body = [updated_node]
return updated_node
def leave_ClassDef(self, original_node, updated_node):
@@ -457,13 +462,18 @@ class DiffConverterTransformer(CSTTransformer):
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
)
if super_file_name not in self.visited_module: # only extract classes once
visited_module = self.visited_module
if super_file_name not in visited_module: # only extract classes once
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name], model_name, self.model_name
self.transformers_imports[super_file_name],
model_name,
self.model_name,
self.given_old_name,
self.given_new_name,
)
self.visited_module[super_file_name] = class_finder
visited_module[super_file_name] = class_finder
else: # we are re-using the previously parsed data
class_finder = self.visited_module[super_file_name]
class_finder = visited_module[super_file_name]
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
@@ -474,7 +484,7 @@ class DiffConverterTransformer(CSTTransformer):
start_insert_idx = self.global_scope_index
for dependency, _ in list_dependencies:
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if node is not None and "Config" not in class_name:
if dependency not in self.new_body:
start_insert_idx -= 1
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
@@ -485,7 +495,7 @@ class DiffConverterTransformer(CSTTransformer):
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
if "Config" in class_name:
self.config_body = [updated_node]
self.config_body += [updated_node]
else:
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
@@ -503,10 +513,24 @@ class DiffConverterTransformer(CSTTransformer):
def leave_Module(self, original_node: cst.Assign, node):
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
dependency_imports = {}
config_imports = []
for visiter in self.visited_module.values():
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
# manually clean up if it's importing a config from configuration file (ruff doesn't do that)
config_imports = []
for i in list(dependency_imports.values()):
if (
hasattr(i.body[0], "module")
and isinstance(i.body[0].module, cst.Name)
and f"configuration_{self.model_name}" in i.body[0].module.value
):
pass
else:
config_imports.append(i)
if hasattr(self, "config_body"):
self.config_body = list(imports.values()) + self.config_body
self.config_body = list(imports.values()) + config_imports + self.config_body
dependency_imports.update(imports)
new_body = list(dependency_imports.values())
if len(self.new_body.keys()) > 0:
@@ -516,7 +540,7 @@ class DiffConverterTransformer(CSTTransformer):
return node.with_changes(body=[*new_body])
def convert_file(diff_file, cst_transformers=None):
def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
# Parse the Python file
with open(diff_file, "r") as file:
@@ -524,7 +548,7 @@ def convert_file(diff_file, cst_transformers=None):
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module, model_name)
cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = run_ruff(new_mod.code, True)
formatted_code = run_ruff(ruffed_code, False)
@@ -551,10 +575,20 @@ if __name__ == "__main__":
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name)
converter = convert_file(file_name, args.old_model_name, args.new_model_name)