Merge branch 'master' into tf2

This commit is contained in:
thomwolf
2019-09-26 12:02:54 +02:00
10 changed files with 146 additions and 70 deletions

View File

@@ -23,7 +23,7 @@ import shutil
import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
from transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
help="Load student initialization checkpoint.")
parser.add_argument("--from_pretrained_config", default=None, type=str,
help="Load student initialization architecture config.")
parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
help="The teacher BERT model.")
parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
help="Teacher type (BERT, RoBERTa).")
parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
help="The teacher model.")
parser.add_argument("--temperature", default=2., type=float,
help="Temperature for the softmax temperature.")
@@ -81,6 +83,8 @@ def main():
help="Linear weight for the MLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--alpha_cos", default=0.0, type=float,
help="Linear weight of the cosine embedding loss. Must be >=0.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float,
@@ -165,11 +169,14 @@ def main():
### TOKENIZER ###
bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
if args.teacher_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
elif args.teacher_type == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items():
idx = bert_tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx]
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}')
args.special_tok_ids = special_tok_ids
@@ -197,16 +204,17 @@ def main():
## STUDENT ##
if args.from_pretrained_weights is not None:
assert os.path.isfile(os.path.join(args.from_pretrained_weights))
assert os.path.isfile(os.path.join(args.from_pretrained_config))
assert os.path.isfile(args.from_pretrained_weights)
assert os.path.isfile(args.from_pretrained_config)
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
stu_architecture_config.output_hidden_states = True
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DistilBertConfig(**vars(args))
stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
student = DistilBertForMaskedLM(stu_architecture_config)
@@ -216,10 +224,13 @@ def main():
## TEACHER ##
teacher = BertForMaskedLM.from_pretrained(args.bert_model)
if args.teacher_type == 'bert':
teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
elif args.teacher_type == 'roberta':
teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.bert_model}.')
logger.info(f'Teacher loaded from {args.teacher_name}.')
## DISTILLER ##
torch.cuda.empty_cache()