[distillation] big update w/ new weights

This commit is contained in:
VictorSanh
2019-09-19 19:25:21 +00:00
parent 0d1dad6d53
commit 354944e607
7 changed files with 132 additions and 54 deletions

View File

@@ -23,7 +23,7 @@ import shutil
import numpy as np
import torch
from pytorch_transformers import BertTokenizer, BertForMaskedLM
from pytorch_transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
help="Load student initialization checkpoint.")
parser.add_argument("--from_pretrained_config", default=None, type=str,
help="Load student initialization architecture config.")
parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
help="The teacher BERT model.")
parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
help="Teacher type (BERT, RoBERTa).")
parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
help="The teacher model.")
parser.add_argument("--temperature", default=2., type=float,
help="Temperature for the softmax temperature.")
@@ -81,6 +83,8 @@ def main():
help="Linear weight for the MLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--alpha_cos", default=0.0, type=float,
help="Linear weight of the cosine embedding loss. Must be >=0.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float,
@@ -165,11 +169,14 @@ def main():
### TOKENIZER ###
bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
if args.teacher_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
elif args.teacher_type == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items():
idx = bert_tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx]
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}')
args.special_tok_ids = special_tok_ids
@@ -202,11 +209,12 @@ def main():
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
stu_architecture_config.output_hidden_states = True
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DistilBertConfig(**vars(args))
stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
student = DistilBertForMaskedLM(stu_architecture_config)
@@ -216,10 +224,13 @@ def main():
## TEACHER ##
teacher = BertForMaskedLM.from_pretrained(args.bert_model)
if args.teacher_type == 'bert':
teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
elif args.teacher_type == 'roberta':
teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.bert_model}.')
logger.info(f'Teacher loaded from {args.teacher_name}.')
## DISTILLER ##
torch.cuda.empty_cache()