Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -229,20 +229,14 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = (
|
||||
loss_fct(
|
||||
loss_start = loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
loss_fct(
|
||||
) * (args.temperature**2)
|
||||
loss_end = loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
) * (args.temperature**2)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
@@ -285,14 +285,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_logits = (
|
||||
nn.functional.kl_div(
|
||||
loss_logits = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
) * (args.temperature**2)
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
|
||||
@@ -306,22 +306,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_start = (
|
||||
nn.functional.kl_div(
|
||||
loss_start = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
nn.functional.kl_div(
|
||||
) * (args.temperature**2)
|
||||
loss_end = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
) * (args.temperature**2)
|
||||
loss_logits = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
4
setup.py
4
setup.py
@@ -93,7 +93,7 @@ if stale_egg_info.exists():
|
||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow",
|
||||
"black==21.4b0",
|
||||
"black~=22.0",
|
||||
"codecarbon==1.2.0",
|
||||
"cookiecutter==1.7.2",
|
||||
"dataclasses",
|
||||
@@ -166,7 +166,7 @@ _deps = [
|
||||
# packaging: "packaging"
|
||||
#
|
||||
# some of the values are versioned whereas others aren't.
|
||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)}
|
||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
|
||||
|
||||
# since we save this data in src/transformers/dependency_versions_table.py it can be easily accessed from
|
||||
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
|
||||
|
||||
@@ -292,7 +292,7 @@ def replace_model_patterns(
|
||||
attributes_to_check.append("model_type")
|
||||
else:
|
||||
text = re.sub(
|
||||
fr'(\s*)model_type = "{old_model_patterns.model_type}"',
|
||||
rf'(\s*)model_type = "{old_model_patterns.model_type}"',
|
||||
r'\1model_type = "[MODEL_TYPE]"',
|
||||
text,
|
||||
)
|
||||
@@ -301,8 +301,8 @@ def replace_model_patterns(
|
||||
# not the new one. We can't just do a replace in all the text and will need a special regex
|
||||
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
|
||||
old_model_value = old_model_patterns.model_upper_cased
|
||||
if re.search(fr"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
|
||||
text = re.sub(fr"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
|
||||
if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
|
||||
text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
|
||||
else:
|
||||
attributes_to_check.append("model_upper_cased")
|
||||
|
||||
@@ -750,8 +750,8 @@ def clean_frameworks_in_init(
|
||||
return
|
||||
|
||||
remove_pattern = "|".join(to_remove)
|
||||
re_conditional_imports = re.compile(fr"^\s*if is_({remove_pattern})_available\(\):\s*$")
|
||||
re_is_xxx_available = re.compile(fr"is_({remove_pattern})_available")
|
||||
re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
|
||||
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
|
||||
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@@ -831,7 +831,7 @@ def add_model_to_main_init(
|
||||
if framework is not None and frameworks is not None and framework not in frameworks:
|
||||
new_lines.append(lines[idx])
|
||||
idx += 1
|
||||
elif re.search(fr'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
|
||||
elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
|
||||
block = [lines[idx]]
|
||||
indent = find_indent(lines[idx])
|
||||
idx += 1
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# 2. run `make deps_table_update``
|
||||
deps = {
|
||||
"Pillow": "Pillow",
|
||||
"black": "black==21.4b0",
|
||||
"black": "black~=22.0",
|
||||
"codecarbon": "codecarbon==1.2.0",
|
||||
"cookiecutter": "cookiecutter==1.7.2",
|
||||
"dataclasses": "dataclasses",
|
||||
|
||||
@@ -405,13 +405,10 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
||||
else:
|
||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
|
||||
attention_mask = (
|
||||
tf.cast(
|
||||
attention_mask = tf.cast(
|
||||
attention_mask > question_end_index,
|
||||
dtype=question_end_index.dtype,
|
||||
)
|
||||
* tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
|
||||
)
|
||||
) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@@ -68,7 +68,8 @@ class CopyCheckTester(unittest.TestCase):
|
||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||
if overwrite_result is not None:
|
||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||
code = black.format_str(code, mode=mode)
|
||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||
with open(fname, "w", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
@@ -88,7 +88,7 @@ def find_code_in_transformers(object_name):
|
||||
line_index = 0
|
||||
for name in parts[i + 1 :]:
|
||||
while (
|
||||
line_index < len(lines) and re.search(fr"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
||||
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
||||
):
|
||||
line_index += 1
|
||||
indent += " "
|
||||
@@ -130,7 +130,8 @@ def blackify(code):
|
||||
has_indent = len(get_indent(code)) > 0
|
||||
if has_indent:
|
||||
code = f"class Bla:\n{code}"
|
||||
result = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||
result = black.format_str(code, mode=mode)
|
||||
result, _ = style_docstrings_in_code(result)
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ fork_point_sha = subprocess.check_output("git merge-base master HEAD".split()).d
|
||||
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()
|
||||
|
||||
joined_dirs = "|".join(sys.argv[1:])
|
||||
regex = re.compile(fr"^({joined_dirs}).*?\.py$")
|
||||
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
|
||||
|
||||
relevant_modified_files = [x for x in modified_files if regex.match(x)]
|
||||
print(" ".join(relevant_modified_files), end="")
|
||||
|
||||
@@ -147,9 +147,8 @@ def format_code_example(code: str, max_len: int, in_docstring: bool = False):
|
||||
for k, v in BLACK_AVOID_PATTERNS.items():
|
||||
full_code = full_code.replace(k, v)
|
||||
try:
|
||||
formatted_code = black.format_str(
|
||||
full_code, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length)
|
||||
)
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length)
|
||||
formatted_code = black.format_str(full_code, mode=mode)
|
||||
error = ""
|
||||
except Exception as e:
|
||||
formatted_code = full_code
|
||||
|
||||
Reference in New Issue
Block a user