Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0
* Check copies
* Fix code
This commit is contained in:
@@ -229,20 +229,14 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
assert end_logits_tea.size() == end_logits_stu.size()
|
assert end_logits_tea.size() == end_logits_stu.size()
|
||||||
|
|
||||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
loss_start = (
|
loss_start = loss_fct(
|
||||||
loss_fct(
|
|
||||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||||
)
|
) * (args.temperature**2)
|
||||||
* (args.temperature ** 2)
|
loss_end = loss_fct(
|
||||||
)
|
|
||||||
loss_end = (
|
|
||||||
loss_fct(
|
|
||||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||||
)
|
) * (args.temperature**2)
|
||||||
* (args.temperature ** 2)
|
|
||||||
)
|
|
||||||
loss_ce = (loss_start + loss_end) / 2.0
|
loss_ce = (loss_start + loss_end) / 2.0
|
||||||
|
|
||||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||||
|
|||||||
@@ -285,14 +285,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_logits = (
|
loss_logits = nn.functional.kl_div(
|
||||||
nn.functional.kl_div(
|
|
||||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
) * (args.temperature**2)
|
||||||
* (args.temperature ** 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||||
|
|
||||||
|
|||||||
@@ -306,22 +306,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_start = (
|
loss_start = nn.functional.kl_div(
|
||||||
nn.functional.kl_div(
|
|
||||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
) * (args.temperature**2)
|
||||||
* (args.temperature ** 2)
|
loss_end = nn.functional.kl_div(
|
||||||
)
|
|
||||||
loss_end = (
|
|
||||||
nn.functional.kl_div(
|
|
||||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
) * (args.temperature**2)
|
||||||
* (args.temperature ** 2)
|
|
||||||
)
|
|
||||||
loss_logits = (loss_start + loss_end) / 2.0
|
loss_logits = (loss_start + loss_end) / 2.0
|
||||||
|
|
||||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -93,7 +93,7 @@ if stale_egg_info.exists():
|
|||||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||||
_deps = [
|
_deps = [
|
||||||
"Pillow",
|
"Pillow",
|
||||||
"black==21.4b0",
|
"black~=22.0",
|
||||||
"codecarbon==1.2.0",
|
"codecarbon==1.2.0",
|
||||||
"cookiecutter==1.7.2",
|
"cookiecutter==1.7.2",
|
||||||
"dataclasses",
|
"dataclasses",
|
||||||
@@ -166,7 +166,7 @@ _deps = [
|
|||||||
# packaging: "packaging"
|
# packaging: "packaging"
|
||||||
#
|
#
|
||||||
# some of the values are versioned whereas others aren't.
|
# some of the values are versioned whereas others aren't.
|
||||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)}
|
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
|
||||||
|
|
||||||
# since we save this data in src/transformers/dependency_versions_table.py it can be easily accessed from
|
# since we save this data in src/transformers/dependency_versions_table.py it can be easily accessed from
|
||||||
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
|
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ def replace_model_patterns(
|
|||||||
attributes_to_check.append("model_type")
|
attributes_to_check.append("model_type")
|
||||||
else:
|
else:
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
fr'(\s*)model_type = "{old_model_patterns.model_type}"',
|
rf'(\s*)model_type = "{old_model_patterns.model_type}"',
|
||||||
r'\1model_type = "[MODEL_TYPE]"',
|
r'\1model_type = "[MODEL_TYPE]"',
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
@@ -301,8 +301,8 @@ def replace_model_patterns(
|
|||||||
# not the new one. We can't just do a replace in all the text and will need a special regex
|
# not the new one. We can't just do a replace in all the text and will need a special regex
|
||||||
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
|
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
|
||||||
old_model_value = old_model_patterns.model_upper_cased
|
old_model_value = old_model_patterns.model_upper_cased
|
||||||
if re.search(fr"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
|
if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
|
||||||
text = re.sub(fr"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
|
text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
|
||||||
else:
|
else:
|
||||||
attributes_to_check.append("model_upper_cased")
|
attributes_to_check.append("model_upper_cased")
|
||||||
|
|
||||||
@@ -750,8 +750,8 @@ def clean_frameworks_in_init(
|
|||||||
return
|
return
|
||||||
|
|
||||||
remove_pattern = "|".join(to_remove)
|
remove_pattern = "|".join(to_remove)
|
||||||
re_conditional_imports = re.compile(fr"^\s*if is_({remove_pattern})_available\(\):\s*$")
|
re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
|
||||||
re_is_xxx_available = re.compile(fr"is_({remove_pattern})_available")
|
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
|
||||||
|
|
||||||
with open(init_file, "r", encoding="utf-8") as f:
|
with open(init_file, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
@@ -831,7 +831,7 @@ def add_model_to_main_init(
|
|||||||
if framework is not None and frameworks is not None and framework not in frameworks:
|
if framework is not None and frameworks is not None and framework not in frameworks:
|
||||||
new_lines.append(lines[idx])
|
new_lines.append(lines[idx])
|
||||||
idx += 1
|
idx += 1
|
||||||
elif re.search(fr'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
|
elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
|
||||||
block = [lines[idx]]
|
block = [lines[idx]]
|
||||||
indent = find_indent(lines[idx])
|
indent = find_indent(lines[idx])
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# 2. run `make deps_table_update``
|
# 2. run `make deps_table_update``
|
||||||
deps = {
|
deps = {
|
||||||
"Pillow": "Pillow",
|
"Pillow": "Pillow",
|
||||||
"black": "black==21.4b0",
|
"black": "black~=22.0",
|
||||||
"codecarbon": "codecarbon==1.2.0",
|
"codecarbon": "codecarbon==1.2.0",
|
||||||
"cookiecutter": "cookiecutter==1.7.2",
|
"cookiecutter": "cookiecutter==1.7.2",
|
||||||
"dataclasses": "dataclasses",
|
"dataclasses": "dataclasses",
|
||||||
|
|||||||
@@ -405,13 +405,10 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
|||||||
else:
|
else:
|
||||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||||
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
|
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
|
||||||
attention_mask = (
|
attention_mask = tf.cast(
|
||||||
tf.cast(
|
|
||||||
attention_mask > question_end_index,
|
attention_mask > question_end_index,
|
||||||
dtype=question_end_index.dtype,
|
dtype=question_end_index.dtype,
|
||||||
)
|
) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
|
||||||
* tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ class CopyCheckTester(unittest.TestCase):
|
|||||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||||
if overwrite_result is not None:
|
if overwrite_result is not None:
|
||||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||||
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||||
|
code = black.format_str(code, mode=mode)
|
||||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||||
with open(fname, "w", newline="\n") as f:
|
with open(fname, "w", newline="\n") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ def find_code_in_transformers(object_name):
|
|||||||
line_index = 0
|
line_index = 0
|
||||||
for name in parts[i + 1 :]:
|
for name in parts[i + 1 :]:
|
||||||
while (
|
while (
|
||||||
line_index < len(lines) and re.search(fr"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
||||||
):
|
):
|
||||||
line_index += 1
|
line_index += 1
|
||||||
indent += " "
|
indent += " "
|
||||||
@@ -130,7 +130,8 @@ def blackify(code):
|
|||||||
has_indent = len(get_indent(code)) > 0
|
has_indent = len(get_indent(code)) > 0
|
||||||
if has_indent:
|
if has_indent:
|
||||||
code = f"class Bla:\n{code}"
|
code = f"class Bla:\n{code}"
|
||||||
result = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||||
|
result = black.format_str(code, mode=mode)
|
||||||
result, _ = style_docstrings_in_code(result)
|
result, _ = style_docstrings_in_code(result)
|
||||||
return result[len("class Bla:\n") :] if has_indent else result
|
return result[len("class Bla:\n") :] if has_indent else result
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ fork_point_sha = subprocess.check_output("git merge-base master HEAD".split()).d
|
|||||||
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()
|
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()
|
||||||
|
|
||||||
joined_dirs = "|".join(sys.argv[1:])
|
joined_dirs = "|".join(sys.argv[1:])
|
||||||
regex = re.compile(fr"^({joined_dirs}).*?\.py$")
|
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
|
||||||
|
|
||||||
relevant_modified_files = [x for x in modified_files if regex.match(x)]
|
relevant_modified_files = [x for x in modified_files if regex.match(x)]
|
||||||
print(" ".join(relevant_modified_files), end="")
|
print(" ".join(relevant_modified_files), end="")
|
||||||
|
|||||||
@@ -147,9 +147,8 @@ def format_code_example(code: str, max_len: int, in_docstring: bool = False):
|
|||||||
for k, v in BLACK_AVOID_PATTERNS.items():
|
for k, v in BLACK_AVOID_PATTERNS.items():
|
||||||
full_code = full_code.replace(k, v)
|
full_code = full_code.replace(k, v)
|
||||||
try:
|
try:
|
||||||
formatted_code = black.format_str(
|
mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length)
|
||||||
full_code, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length)
|
formatted_code = black.format_str(full_code, mode=mode)
|
||||||
)
|
|
||||||
error = ""
|
error = ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
formatted_code = full_code
|
formatted_code = full_code
|
||||||
|
|||||||
Reference in New Issue
Block a user