weigths*weights

This commit is contained in:
Julien Chaumond
2020-04-04 15:03:26 -04:00
parent 243e687be6
commit 94eb68d742
4 changed files with 4 additions and 4 deletions

View File

@@ -136,7 +136,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
if "kernel" in name or "proj" in name:
array = np.transpose(array)
if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1:
# Here we will split the TF weigths
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]

View File

@@ -156,7 +156,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
logger.info("Transposing")
array = np.transpose(array)
if isinstance(pointer, list):
# Here we will split the TF weigths
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]