Fix E266 flake8 warning (x90).

This commit is contained in:
Aymeric Augustin
2019-12-21 21:22:55 +01:00
parent 2ab78325f0
commit fa2ccbc081
30 changed files with 92 additions and 90 deletions

View File

@@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
)
@@ -486,7 +486,7 @@ def main():
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

View File

@@ -43,7 +43,7 @@ if __name__ == "__main__":
state_dict = model.state_dict()
compressed_sd = {}
### Embeddings ###
# Embeddings #
if args.model_type == "gpt2":
for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ###
# Transformer Blocks #
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
]
std_idx += 1
### Language Modeling Head ###s
# Language Modeling Head ###s
if args.model_type == "roberta":
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]

View File

@@ -219,7 +219,7 @@ def main():
args = parser.parse_args()
sanity_checks(args)
## ARGS ##
# ARGS #
init_gpu_params(args)
set_seed(args)
if args.is_master:
@@ -236,7 +236,7 @@ def main():
os.makedirs(args.dump_path)
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
### SAVE PARAMS ###
# SAVE PARAMS #
logger.info(f"Param: {args}")
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
### TOKENIZER ###
# TOKENIZER #
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
## DATA LOADER ##
# DATA LOADER #
logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info(f"Data loader created.")
## STUDENT ##
# STUDENT #
logger.info(f"Loading student config from {args.student_config}")
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
student.to(f"cuda:{args.local_rank}")
logger.info(f"Student loaded.")
## TEACHER ##
# TEACHER #
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f"cuda:{args.local_rank}")
logger.info(f"Teacher loaded from {args.teacher_name}.")
## FREEZING ##
# FREEZING #
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
## SANITY CHECKS ##
# SANITY CHECKS #
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
## DISTILLER ##
# DISTILLER #
torch.cuda.empty_cache()
distiller = Distiller(
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher