Fix E266 flake8 warning (x90).

This commit is contained in:
Aymeric Augustin
2019-12-21 21:22:55 +01:00
parent 2ab78325f0
commit fa2ccbc081
30 changed files with 92 additions and 90 deletions

View File

@@ -219,7 +219,7 @@ def main():
args = parser.parse_args()
sanity_checks(args)
## ARGS ##
# ARGS #
init_gpu_params(args)
set_seed(args)
if args.is_master:
@@ -236,7 +236,7 @@ def main():
os.makedirs(args.dump_path)
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
### SAVE PARAMS ###
# SAVE PARAMS #
logger.info(f"Param: {args}")
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
### TOKENIZER ###
# TOKENIZER #
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
## DATA LOADER ##
# DATA LOADER #
logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info(f"Data loader created.")
## STUDENT ##
# STUDENT #
logger.info(f"Loading student config from {args.student_config}")
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
student.to(f"cuda:{args.local_rank}")
logger.info(f"Student loaded.")
## TEACHER ##
# TEACHER #
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f"cuda:{args.local_rank}")
logger.info(f"Teacher loaded from {args.teacher_name}.")
## FREEZING ##
# FREEZING #
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
## SANITY CHECKS ##
# SANITY CHECKS #
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
## DISTILLER ##
# DISTILLER #
torch.cuda.empty_cache()
distiller = Distiller(
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher