Copy code from Bert to Roberta and add safeguard script (#7219)
* Copy code from Bert to Roberta and add safeguard script * Fix docstring * Comment code * Formatting * Update src/transformers/modeling_roberta.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add test and fix bugs * Fix style and make new comand Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
181
utils/check_copies.py
Normal file
181
utils/check_copies.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
""" Find and return the code source code of `object_name`."""
|
||||
parts = object_name.split(".")
|
||||
i = 0
|
||||
|
||||
# First let's find the module where our object lives.
|
||||
module = parts[i]
|
||||
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
|
||||
i += 1
|
||||
module = os.path.join(module, parts[i])
|
||||
if i >= len(parts):
|
||||
raise ValueError(
|
||||
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
|
||||
)
|
||||
|
||||
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Now let's find the class / func in the code!
|
||||
indent = ""
|
||||
line_index = 0
|
||||
for name in parts[i + 1 :]:
|
||||
while line_index < len(lines) and re.search(f"^{indent}(class|def)\s+{name}", lines[line_index]) is None:
|
||||
line_index += 1
|
||||
indent += " "
|
||||
line_index += 1
|
||||
|
||||
if line_index >= len(lines):
|
||||
raise ValueError(f" {object_name} does not match any function or class in {module}.")
|
||||
|
||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||
start_index = line_index
|
||||
while line_index < len(lines) and (lines[line_index].startswith(indent) or len(lines[line_index]) <= 1):
|
||||
line_index += 1
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
line_index -= 1
|
||||
|
||||
code_lines = lines[start_index:line_index]
|
||||
return "".join(code_lines)
|
||||
|
||||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_replace_pattern = re.compile(r"with\s+(\S+)->(\S+)(?:\s|$)")
|
||||
|
||||
|
||||
def blackify(code):
|
||||
"""
|
||||
Applies the black part of our `make style` command to `code`.
|
||||
"""
|
||||
has_indent = code.startswith(" ")
|
||||
if has_indent:
|
||||
code = f"class Bla:\n{code}"
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
fname = os.path.join(d, "tmp.py")
|
||||
with open(fname, "w") as f:
|
||||
f.write(code)
|
||||
os.system(f"black -q --line-length 119 --target-version py35 {fname}")
|
||||
with open(fname, "r") as f:
|
||||
result = f.read()
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
def is_copy_consistent(filename, overwrite=False):
|
||||
"""
|
||||
Check if the code commented as a copy in `filename` matches the original.
|
||||
|
||||
Return the differences or overwrites the content depending on `overwrite`.
|
||||
"""
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
found_diff = False
|
||||
line_index = 0
|
||||
# Not a foor loop cause `lines` is going to change (if `overwrite=True`).
|
||||
while line_index < len(lines):
|
||||
search = _re_copy_warning.search(lines[line_index])
|
||||
if search is None:
|
||||
line_index += 1
|
||||
continue
|
||||
|
||||
# There is some copied code here, let's retrieve the original.
|
||||
indent, object_name, replace_pattern = search.groups()
|
||||
theoretical_code = find_code_in_transformers(object_name)
|
||||
theoretical_indent = re.search(r"^(\s*)\S", theoretical_code).groups()[0]
|
||||
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
|
||||
indent = theoretical_indent
|
||||
line_index = start_index
|
||||
|
||||
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
|
||||
should_continue = True
|
||||
while line_index < len(lines) and should_continue:
|
||||
line_index += 1
|
||||
if line_index >= len(lines):
|
||||
break
|
||||
line = lines[line_index]
|
||||
should_continue = (len(line) <= 1 or line.startswith(indent)) and re.search(
|
||||
f"^{indent}# End copy", line
|
||||
) is None
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
line_index -= 1
|
||||
|
||||
observed_code_lines = lines[start_index:line_index]
|
||||
observed_code = "".join(observed_code_lines)
|
||||
|
||||
# Before comparing, use the `replace_pattern` on the original code.
|
||||
if len(replace_pattern) > 0:
|
||||
search_patterns = _re_replace_pattern.search(replace_pattern)
|
||||
if search_patterns is not None:
|
||||
obj1, obj2 = search_patterns.groups()
|
||||
theoretical_code = re.sub(obj1, obj2, theoretical_code)
|
||||
|
||||
# Blackify each version before comparing them.
|
||||
observed_code = blackify(observed_code)
|
||||
theoretical_code = blackify(theoretical_code)
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
if observed_code != theoretical_code:
|
||||
found_diff = True
|
||||
if overwrite:
|
||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
||||
line_index = start_index + 1
|
||||
|
||||
if overwrite and found_diff:
|
||||
# Warn the user a file has been modified.
|
||||
print(f"Detected changes, rewriting {filename}.")
|
||||
with open(filename, "w") as f:
|
||||
f.writelines(lines)
|
||||
return not found_diff
|
||||
|
||||
|
||||
def check_copies(overwrite: bool = False):
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
consistent = is_copy_consistent(filename, overwrite)
|
||||
if not consistent:
|
||||
diffs.append(filename)
|
||||
if not overwrite and len(diffs) > 0:
|
||||
diff = "\n".join(diffs)
|
||||
raise Exception(
|
||||
"Found copy inconsistencies in the following files:\n"
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_copies(args.fix_and_overwrite)
|
||||
@@ -1,3 +1,18 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
Reference in New Issue
Block a user