Fix: unpin flake8 and fix cs errors (#4367)
* Fix: unpin flake8 and fix cs errors
* Ok we still need to quote those
This commit is contained in:
@@ -80,7 +80,7 @@ class Distiller:
|
||||
|
||||
self.mlm = params.mlm
|
||||
if self.mlm:
|
||||
logger.info(f"Using MLM loss for LM step.")
|
||||
logger.info("Using MLM loss for LM step.")
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
@@ -91,7 +91,7 @@ class Distiller:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
else:
|
||||
logger.info(f"Using CLM loss for LM step.")
|
||||
logger.info("Using CLM loss for LM step.")
|
||||
|
||||
self.epoch = 0
|
||||
self.n_iter = 0
|
||||
@@ -365,8 +365,8 @@ class Distiller:
|
||||
self.end_epoch()
|
||||
|
||||
if self.is_master:
|
||||
logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
|
||||
self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
|
||||
logger.info("Save very last checkpoint as `pytorch_model.bin`.")
|
||||
self.save_checkpoint(checkpoint_name="pytorch_model.bin")
|
||||
logger.info("Training is finished")
|
||||
|
||||
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
|
||||
|
||||
@@ -60,7 +60,7 @@ def main():
|
||||
with open(args.file_path, "r", encoding="utf8") as fp:
|
||||
data = fp.readlines()
|
||||
|
||||
logger.info(f"Start encoding")
|
||||
logger.info("Start encoding")
|
||||
logger.info(f"{len(data)} examples to process.")
|
||||
|
||||
rslt = []
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == "__main__":
|
||||
elif args.model_type == "gpt2":
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
|
||||
compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
|
||||
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = "bert"
|
||||
else:
|
||||
raise ValueError(f'args.model_type should be "bert".')
|
||||
raise ValueError('args.model_type should be "bert".')
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
@@ -78,12 +78,12 @@ if __name__ == "__main__":
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
|
||||
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
|
||||
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
|
||||
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
|
||||
if args.vocab_transform:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
@@ -273,7 +273,7 @@ def main():
|
||||
token_probs = None
|
||||
|
||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||
logger.info(f"Data loader created.")
|
||||
logger.info("Data loader created.")
|
||||
|
||||
# STUDENT #
|
||||
logger.info(f"Loading student config from {args.student_config}")
|
||||
@@ -288,7 +288,7 @@ def main():
|
||||
|
||||
if args.n_gpu > 0:
|
||||
student.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Student loaded.")
|
||||
logger.info("Student loaded.")
|
||||
|
||||
# TEACHER #
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||
|
||||
Reference in New Issue
Block a user