Fix E741 flake8 warning (x14).

This commit is contained in:
Aymeric Augustin
2019-12-21 18:25:59 +01:00
parent ea89bec185
commit b0f7db73cd
8 changed files with 60 additions and 60 deletions

View File

@@ -89,25 +89,25 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
l = re.split(r"_(\d+)", m_name)
scope_names = re.split(r"_(\d+)", m_name)
else:
l = [m_name]
if l[0] == "kernel" or l[0] == "gamma":
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif l[0] == "output_bias" or l[0] == "beta":
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif l[0] == "output_weights":
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif l[0] == "squad":
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, l[0])
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")