Check copies blackify (#10775)
* Apply black before checking copies * Fix for class methods * Deal with lonely brackets * Remove debug and add forward changes * Separate copies and fix test * Add black as a test dependency
This commit is contained in:
@@ -19,6 +19,8 @@ import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import black
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
@@ -66,6 +68,7 @@ class CopyCheckTester(unittest.TestCase):
|
||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||
if overwrite_result is not None:
|
||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||
with open(fname, "w") as f:
|
||||
f.write(code)
|
||||
@@ -103,7 +106,7 @@ class CopyCheckTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Copy consistency with a really long name
|
||||
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReasonIReallyDontUnderstand"
|
||||
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
|
||||
self.check_copy_consistency(
|
||||
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
|
||||
f"{long_class_name}LMPredictionHead",
|
||||
|
||||
Reference in New Issue
Block a user