Fix E741 flake8 warning (x14).
This commit is contained in:
@@ -89,25 +89,25 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
l = re.split(r"_(\d+)", m_name)
|
||||
scope_names = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == "kernel" or l[0] == "gamma":
|
||||
scope_names = [m_name]
|
||||
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "output_bias" or l[0] == "beta":
|
||||
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif l[0] == "output_weights":
|
||||
elif scope_names[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "squad":
|
||||
elif scope_names[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
except AttributeError:
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
if len(scope_names) >= 2:
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == "_embeddings":
|
||||
pointer = getattr(pointer, "weight")
|
||||
|
||||
Reference in New Issue
Block a user