Refactor checkpoint name in BERT and MobileBERT (#10424)
* Refactor checkpoint name in BERT and MobileBERT
* Add option to check copies
* Add QuestionAnswering
* Add last models
* Make black happy
This commit is contained in:
@@ -73,7 +73,7 @@ def find_code_in_transformers(object_name):
|
||||
|
||||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_replace_pattern = re.compile(r"with\s+(\S+)->(\S+)(?:\s|$)")
|
||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||
|
||||
|
||||
def blackify(code):
|
||||
@@ -93,6 +93,16 @@ def blackify(code):
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
def get_indent(code):
    """
    Return the leading whitespace of the first line of `code` that has content.

    Args:
        code (:obj:`str`): The source snippet to inspect.

    Returns:
        :obj:`str`: The run of whitespace preceding the first non-whitespace character on the first line that
        contains one, or an empty string when `code` has no such line (empty or whitespace-only input).
    """
    for line in code.split("\n"):
        # Lines that are empty or contain only whitespace carry no indentation information: skip them instead
        # of calling `.groups()` on a failed match.
        match = re.search(r"^(\s*)\S", line)
        if match is not None:
            return match.groups()[0]
    # Always return a string so the result can be compared with (and used as) an indent by callers.
    return ""
|
||||
|
||||
|
||||
def is_copy_consistent(filename, overwrite=False):
|
||||
"""
|
||||
Check if the code commented as a copy in `filename` matches the original.
|
||||
@@ -113,7 +123,7 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
# There is some copied code here, let's retrieve the original.
|
||||
indent, object_name, replace_pattern = search.groups()
|
||||
theoretical_code = find_code_in_transformers(object_name)
|
||||
theoretical_indent = re.search(r"^(\s*)\S", theoretical_code).groups()[0]
|
||||
theoretical_indent = get_indent(theoretical_code)
|
||||
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
|
||||
indent = theoretical_indent
|
||||
@@ -138,10 +148,16 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
|
||||
# Before comparing, use the `replace_pattern` on the original code.
|
||||
if len(replace_pattern) > 0:
|
||||
search_patterns = _re_replace_pattern.search(replace_pattern)
|
||||
if search_patterns is not None:
|
||||
obj1, obj2 = search_patterns.groups()
|
||||
patterns = replace_pattern.replace("with", "").split(",")
|
||||
patterns = [_re_replace_pattern.search(p) for p in patterns]
|
||||
for pattern in patterns:
|
||||
if pattern is None:
|
||||
continue
|
||||
obj1, obj2, option = pattern.groups()
|
||||
theoretical_code = re.sub(obj1, obj2, theoretical_code)
|
||||
if option.strip() == "all-casing":
|
||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
if observed_code != theoretical_code:
|
||||
|
||||
Reference in New Issue
Block a user