Fix E266 flake8 warning (x90).
This commit is contained in:
@@ -43,7 +43,7 @@ if __name__ == "__main__":
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
### Embeddings ###
|
||||
# Embeddings #
|
||||
if args.model_type == "gpt2":
|
||||
for param_name in ["wte.weight", "wpe.weight"]:
|
||||
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||
@@ -55,7 +55,7 @@ if __name__ == "__main__":
|
||||
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
|
||||
### Transformer Blocks ###
|
||||
# Transformer Blocks #
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
if args.model_type == "gpt2":
|
||||
@@ -82,7 +82,7 @@ if __name__ == "__main__":
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
### Language Modeling Head ###s
|
||||
# Language Modeling Head ###s
|
||||
if args.model_type == "roberta":
|
||||
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||
|
||||
Reference in New Issue
Block a user