Upgrade black to version ~=22.0 (#15565)

* Upgrade black to version ~=22.0

* Check copies

* Fix code
This commit is contained in:
Lysandre Debut
2022-02-09 09:28:57 -05:00
committed by GitHub
parent d923f76203
commit 7732d0fe7a
91 changed files with 208 additions and 225 deletions

View File

@@ -229,20 +229,14 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
assert end_logits_tea.size() == end_logits_stu.size()
loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = (
loss_fct(
loss_start = loss_fct(
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
loss_end = (
loss_fct(
) * (args.temperature**2)
loss_end = loss_fct(
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
) * (args.temperature**2)
loss_ce = (loss_start + loss_end) / 2.0
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss

View File

@@ -285,14 +285,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
attention_mask=inputs["attention_mask"],
)
loss_logits = (
nn.functional.kl_div(
loss_logits = nn.functional.kl_div(
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
)
) * (args.temperature**2)
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss

View File

@@ -306,22 +306,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
attention_mask=inputs["attention_mask"],
)
loss_start = (
nn.functional.kl_div(
loss_start = nn.functional.kl_div(
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
)
loss_end = (
nn.functional.kl_div(
) * (args.temperature**2)
loss_end = nn.functional.kl_div(
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
)
) * (args.temperature**2)
loss_logits = (loss_start + loss_end) / 2.0
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss

View File

@@ -93,7 +93,7 @@ if stale_egg_info.exists():
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
"black==21.4b0",
"black~=22.0",
"codecarbon==1.2.0",
"cookiecutter==1.7.2",
"dataclasses",
@@ -166,7 +166,7 @@ _deps = [
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)}
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
# since we save this data in src/transformers/dependency_versions_table.py it can be easily accessed from
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:

View File

@@ -292,7 +292,7 @@ def replace_model_patterns(
attributes_to_check.append("model_type")
else:
text = re.sub(
fr'(\s*)model_type = "{old_model_patterns.model_type}"',
rf'(\s*)model_type = "{old_model_patterns.model_type}"',
r'\1model_type = "[MODEL_TYPE]"',
text,
)
@@ -301,8 +301,8 @@ def replace_model_patterns(
# not the new one. We can't just do a replace in all the text and will need a special regex
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
old_model_value = old_model_patterns.model_upper_cased
if re.search(fr"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
text = re.sub(fr"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
else:
attributes_to_check.append("model_upper_cased")
@@ -750,8 +750,8 @@ def clean_frameworks_in_init(
return
remove_pattern = "|".join(to_remove)
re_conditional_imports = re.compile(fr"^\s*if is_({remove_pattern})_available\(\):\s*$")
re_is_xxx_available = re.compile(fr"is_({remove_pattern})_available")
re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
with open(init_file, "r", encoding="utf-8") as f:
content = f.read()
@@ -831,7 +831,7 @@ def add_model_to_main_init(
if framework is not None and frameworks is not None and framework not in frameworks:
new_lines.append(lines[idx])
idx += 1
elif re.search(fr'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
block = [lines[idx]]
indent = find_indent(lines[idx])
idx += 1

View File

@@ -3,7 +3,7 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"black": "black==21.4b0",
"black": "black~=22.0",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",

View File

@@ -405,13 +405,10 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
else:
# last token is separation token and should not be counted and in the middle are two separation tokens
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
attention_mask = (
tf.cast(
attention_mask = tf.cast(
attention_mask > question_end_index,
dtype=question_end_index.dtype,
)
* tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
)
) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
return attention_mask

View File

@@ -68,7 +68,8 @@ class CopyCheckTester(unittest.TestCase):
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
if overwrite_result is not None:
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
code = black.format_str(code, mode=mode)
fname = os.path.join(self.transformer_dir, "new_code.py")
with open(fname, "w", newline="\n") as f:
f.write(code)

View File

@@ -88,7 +88,7 @@ def find_code_in_transformers(object_name):
line_index = 0
for name in parts[i + 1 :]:
while (
line_index < len(lines) and re.search(fr"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
):
line_index += 1
indent += " "
@@ -130,7 +130,8 @@ def blackify(code):
has_indent = len(get_indent(code)) > 0
if has_indent:
code = f"class Bla:\n{code}"
result = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
result = black.format_str(code, mode=mode)
result, _ = style_docstrings_in_code(result)
return result[len("class Bla:\n") :] if has_indent else result

View File

@@ -28,7 +28,7 @@ fork_point_sha = subprocess.check_output("git merge-base master HEAD".split()).d
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()
joined_dirs = "|".join(sys.argv[1:])
regex = re.compile(fr"^({joined_dirs}).*?\.py$")
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
relevant_modified_files = [x for x in modified_files if regex.match(x)]
print(" ".join(relevant_modified_files), end="")

View File

@@ -147,9 +147,8 @@ def format_code_example(code: str, max_len: int, in_docstring: bool = False):
for k, v in BLACK_AVOID_PATTERNS.items():
full_code = full_code.replace(k, v)
try:
formatted_code = black.format_str(
full_code, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length)
)
mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length)
formatted_code = black.format_str(full_code, mode=mode)
error = ""
except Exception as e:
formatted_code = full_code