Fix: unpin flake8 and fix cs errors (#4367)
* Fix: unpin flake8 and fix cs errors * Ok we still need to quote those
This commit is contained in:
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "bert"
|
||||
else:
|
||||
raise ValueError(f'args.model_type should be "bert".')
|
||||
raise ValueError('args.model_type should be "bert".')
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
@@ -78,12 +78,12 @@ if __name__ == "__main__":
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
|
||||
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
|
||||
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
|
||||
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict["cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict["cls.predictions.transform.LayerNorm.{w}"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
Reference in New Issue
Block a user