Fix: unpin flake8 and fix cs errors (#4367)

* Fix: unpin flake8 and fix cs errors

* Ok we still need to quote those
This commit is contained in:
Julien Chaumond
2020-05-14 13:14:26 -04:00
committed by GitHub
parent c547f15a17
commit 448c467256
13 changed files with 35 additions and 21 deletions

View File

@@ -60,7 +60,7 @@ def main():
with open(args.file_path, "r", encoding="utf8") as fp:
data = fp.readlines()
logger.info(f"Start encoding")
logger.info("Start encoding")
logger.info(f"{len(data)} examples to process.")
rslt = []

View File

@@ -93,7 +93,7 @@ if __name__ == "__main__":
elif args.model_type == "gpt2":
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
print(f"N layers selected for distillation: {std_idx}")
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")

View File

@@ -37,7 +37,7 @@ if __name__ == "__main__":
model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = "bert"
else:
raise ValueError(f'args.model_type should be "bert".')
raise ValueError('args.model_type should be "bert".')
state_dict = model.state_dict()
compressed_sd = {}
@@ -78,12 +78,12 @@ if __name__ == "__main__":
]
std_idx += 1
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
if args.vocab_transform:
for w in ["weight", "bias"]:
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
compressed_sd[f"vocab_transform.{w}"] = state_dict["cls.predictions.transform.dense.{w}"]
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict["cls.predictions.transform.LayerNorm.{w}"]
print(f"N layers selected for distillation: {std_idx}")
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")