Fix E266 flake8 warning (x90).

This commit is contained in:
Aymeric Augustin
2019-12-21 21:22:55 +01:00
parent 2ab78325f0
commit fa2ccbc081
30 changed files with 92 additions and 90 deletions

View File

@@ -43,7 +43,7 @@ if __name__ == "__main__":
state_dict = model.state_dict()
compressed_sd = {}
### Embeddings ###
# Embeddings #
if args.model_type == "gpt2":
for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ###
# Transformer Blocks #
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
]
std_idx += 1
### Language Modeling Head ###s
# Language Modeling Head ###s
if args.model_type == "roberta":
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]