Check copies blackify (#10775)
* Apply black before checking copies * Fix for class methods * Deal with lonely brackets * Remove debug and add forward changes * Separate copies and fix test * Add black as a test dependency
This commit is contained in:
2
setup.py
2
setup.py
@@ -228,7 +228,7 @@ extras["speech"] = deps_list("soundfile", "torchaudio")
|
|||||||
|
|
||||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||||
extras["testing"] = (
|
extras["testing"] = (
|
||||||
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar")
|
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black")
|
||||||
+ extras["retrieval"]
|
+ extras["retrieval"]
|
||||||
+ extras["modelcreation"]
|
+ extras["modelcreation"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -671,7 +671,6 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoder.forward with MBart->M2M100
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -830,7 +829,6 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->M2M100
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -1398,6 +1398,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
|||||||
""",
|
""",
|
||||||
MOBILEBERT_START_DOCSTRING,
|
MOBILEBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice with Bert->MobileBert all-casing
|
||||||
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1417,7 +1418,6 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
|||||||
output_type=MultipleChoiceModelOutput,
|
output_type=MultipleChoiceModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.forward with Bert->MobileBert all-casing
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -737,8 +737,10 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
the model is configured as a decoder.
|
the model is configured as a decoder.
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
@@ -754,9 +756,10 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
|
|
||||||
if not self.config.is_decoder:
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
|||||||
@@ -872,7 +872,6 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
|||||||
|
|
||||||
return combined_attention_mask
|
return combined_attention_mask
|
||||||
|
|
||||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->Speech2Text
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import black
|
||||||
|
|
||||||
|
|
||||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||||
@@ -66,6 +68,7 @@ class CopyCheckTester(unittest.TestCase):
|
|||||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||||
if overwrite_result is not None:
|
if overwrite_result is not None:
|
||||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||||
|
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||||
with open(fname, "w") as f:
|
with open(fname, "w") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
@@ -103,7 +106,7 @@ class CopyCheckTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Copy consistency with a really long name
|
# Copy consistency with a really long name
|
||||||
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReasonIReallyDontUnderstand"
|
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
|
||||||
self.check_copy_consistency(
|
self.check_copy_consistency(
|
||||||
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
|
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
|
||||||
f"{long_class_name}LMPredictionHead",
|
f"{long_class_name}LMPredictionHead",
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ import argparse
|
|||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
|
||||||
|
import black
|
||||||
|
|
||||||
|
|
||||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||||
@@ -27,6 +28,10 @@ PATH_TO_DOCS = "docs/source"
|
|||||||
REPO_PATH = "."
|
REPO_PATH = "."
|
||||||
|
|
||||||
|
|
||||||
|
def _should_continue(line, indent):
|
||||||
|
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
||||||
|
|
||||||
|
|
||||||
def find_code_in_transformers(object_name):
|
def find_code_in_transformers(object_name):
|
||||||
""" Find and return the code source code of `object_name`."""
|
""" Find and return the code source code of `object_name`."""
|
||||||
parts = object_name.split(".")
|
parts = object_name.split(".")
|
||||||
@@ -62,7 +67,7 @@ def find_code_in_transformers(object_name):
|
|||||||
|
|
||||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||||
start_index = line_index
|
start_index = line_index
|
||||||
while line_index < len(lines) and (lines[line_index].startswith(indent) or len(lines[line_index]) <= 1):
|
while line_index < len(lines) and _should_continue(lines[line_index], indent):
|
||||||
line_index += 1
|
line_index += 1
|
||||||
# Clean up empty lines at the end (if any).
|
# Clean up empty lines at the end (if any).
|
||||||
while len(lines[line_index - 1]) <= 1:
|
while len(lines[line_index - 1]) <= 1:
|
||||||
@@ -76,23 +81,6 @@ _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)
|
|||||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||||
|
|
||||||
|
|
||||||
def blackify(code):
|
|
||||||
"""
|
|
||||||
Applies the black part of our `make style` command to `code`.
|
|
||||||
"""
|
|
||||||
has_indent = code.startswith(" ")
|
|
||||||
if has_indent:
|
|
||||||
code = f"class Bla:\n{code}"
|
|
||||||
with tempfile.TemporaryDirectory() as d:
|
|
||||||
fname = os.path.join(d, "tmp.py")
|
|
||||||
with open(fname, "w", encoding="utf-8", newline="\n") as f:
|
|
||||||
f.write(code)
|
|
||||||
os.system(f"black -q --line-length 119 --target-version py35 {fname}")
|
|
||||||
with open(fname, "r", encoding="utf-8", newline="\n") as f:
|
|
||||||
result = f.read()
|
|
||||||
return result[len("class Bla:\n") :] if has_indent else result
|
|
||||||
|
|
||||||
|
|
||||||
def get_indent(code):
|
def get_indent(code):
|
||||||
lines = code.split("\n")
|
lines = code.split("\n")
|
||||||
idx = 0
|
idx = 0
|
||||||
@@ -100,7 +88,18 @@ def get_indent(code):
|
|||||||
idx += 1
|
idx += 1
|
||||||
if idx < len(lines):
|
if idx < len(lines):
|
||||||
return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
|
return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
|
||||||
return 0
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def blackify(code):
|
||||||
|
"""
|
||||||
|
Applies the black part of our `make style` command to `code`.
|
||||||
|
"""
|
||||||
|
has_indent = len(get_indent(code)) > 0
|
||||||
|
if has_indent:
|
||||||
|
code = f"class Bla:\n{code}"
|
||||||
|
result = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||||
|
return result[len("class Bla:\n") :] if has_indent else result
|
||||||
|
|
||||||
|
|
||||||
def is_copy_consistent(filename, overwrite=False):
|
def is_copy_consistent(filename, overwrite=False):
|
||||||
@@ -136,9 +135,7 @@ def is_copy_consistent(filename, overwrite=False):
|
|||||||
if line_index >= len(lines):
|
if line_index >= len(lines):
|
||||||
break
|
break
|
||||||
line = lines[line_index]
|
line = lines[line_index]
|
||||||
should_continue = (len(line) <= 1 or line.startswith(indent)) and re.search(
|
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
||||||
f"^{indent}# End copy", line
|
|
||||||
) is None
|
|
||||||
# Clean up empty lines at the end (if any).
|
# Clean up empty lines at the end (if any).
|
||||||
while len(lines[line_index - 1]) <= 1:
|
while len(lines[line_index - 1]) <= 1:
|
||||||
line_index -= 1
|
line_index -= 1
|
||||||
@@ -159,6 +156,11 @@ def is_copy_consistent(filename, overwrite=False):
|
|||||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
||||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
||||||
|
|
||||||
|
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
|
||||||
|
# from the previous line
|
||||||
|
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
|
||||||
|
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
|
||||||
|
|
||||||
# Test for a diff and act accordingly.
|
# Test for a diff and act accordingly.
|
||||||
if observed_code != theoretical_code:
|
if observed_code != theoretical_code:
|
||||||
diffs.append([object_name, start_index])
|
diffs.append([object_name, start_index])
|
||||||
|
|||||||
Reference in New Issue
Block a user