Check copies blackify (#10775)
* Apply black before checking copies * Fix for class methods * Deal with lonely brackets * Remove debug and add forward changes * Separate copies and fix test * Add black as a test dependency
This commit is contained in:
@@ -17,7 +17,8 @@ import argparse
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import black
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
@@ -27,6 +28,10 @@ PATH_TO_DOCS = "docs/source"
|
||||
REPO_PATH = "."
|
||||
|
||||
|
||||
def _should_continue(line, indent):
|
||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
""" Find and return the code source code of `object_name`."""
|
||||
parts = object_name.split(".")
|
||||
@@ -62,7 +67,7 @@ def find_code_in_transformers(object_name):
|
||||
|
||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||
start_index = line_index
|
||||
while line_index < len(lines) and (lines[line_index].startswith(indent) or len(lines[line_index]) <= 1):
|
||||
while line_index < len(lines) and _should_continue(lines[line_index], indent):
|
||||
line_index += 1
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
@@ -76,23 +81,6 @@ _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)
|
||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||
|
||||
|
||||
def blackify(code):
|
||||
"""
|
||||
Applies the black part of our `make style` command to `code`.
|
||||
"""
|
||||
has_indent = code.startswith(" ")
|
||||
if has_indent:
|
||||
code = f"class Bla:\n{code}"
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
fname = os.path.join(d, "tmp.py")
|
||||
with open(fname, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(code)
|
||||
os.system(f"black -q --line-length 119 --target-version py35 {fname}")
|
||||
with open(fname, "r", encoding="utf-8", newline="\n") as f:
|
||||
result = f.read()
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
def get_indent(code):
|
||||
lines = code.split("\n")
|
||||
idx = 0
|
||||
@@ -100,7 +88,18 @@ def get_indent(code):
|
||||
idx += 1
|
||||
if idx < len(lines):
|
||||
return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
|
||||
return 0
|
||||
return ""
|
||||
|
||||
|
||||
def blackify(code):
|
||||
"""
|
||||
Applies the black part of our `make style` command to `code`.
|
||||
"""
|
||||
has_indent = len(get_indent(code)) > 0
|
||||
if has_indent:
|
||||
code = f"class Bla:\n{code}"
|
||||
result = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
def is_copy_consistent(filename, overwrite=False):
|
||||
@@ -136,9 +135,7 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
if line_index >= len(lines):
|
||||
break
|
||||
line = lines[line_index]
|
||||
should_continue = (len(line) <= 1 or line.startswith(indent)) and re.search(
|
||||
f"^{indent}# End copy", line
|
||||
) is None
|
||||
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
line_index -= 1
|
||||
@@ -159,6 +156,11 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
||||
|
||||
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
|
||||
# from the previous line
|
||||
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
|
||||
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
if observed_code != theoretical_code:
|
||||
diffs.append([object_name, start_index])
|
||||
|
||||
Reference in New Issue
Block a user