update train.py
This commit is contained in:
@@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Training DistilBERT.
|
Training the distilled model.
|
||||||
|
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
@@ -23,68 +24,96 @@ import shutil
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
|
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
|
||||||
from transformers import DistilBertForMaskedLM, DistilBertConfig
|
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
|
||||||
|
from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
|
||||||
|
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
|
||||||
|
|
||||||
from distiller import Distiller
|
from distiller import Distiller
|
||||||
from utils import git_log, logger, init_gpu_params, set_seed
|
from utils import git_log, logger, init_gpu_params, set_seed
|
||||||
from dataset import Dataset
|
from lm_seqs_dataset import LmSeqsDataset
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_CLASSES = {
|
||||||
|
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||||
|
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||||
|
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
|
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
def sanity_checks(args):
|
||||||
|
"""
|
||||||
|
A bunch of args sanity checks to perform even starting...
|
||||||
|
"""
|
||||||
|
assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
|
||||||
|
assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
|
||||||
|
if args.mlm:
|
||||||
|
assert os.path.isfile(args.token_counts)
|
||||||
|
assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
|
||||||
|
else:
|
||||||
|
assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
|
||||||
|
|
||||||
|
assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert')
|
||||||
|
assert os.path.isfile(args.student_config)
|
||||||
|
if args.student_pretrained_weights is not None:
|
||||||
|
assert os.path.isfile(args.student_pretrained_weights)
|
||||||
|
|
||||||
|
if args.freeze_token_type_embds: assert args.student_type in ['roberta']
|
||||||
|
|
||||||
|
assert args.alpha_ce >= 0.
|
||||||
|
assert args.alpha_mlm >= 0.
|
||||||
|
assert args.alpha_clm >= 0.
|
||||||
|
assert args.alpha_mse >= 0.
|
||||||
|
assert args.alpha_cos >= 0.
|
||||||
|
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
|
||||||
|
|
||||||
|
def freeze_pos_embeddings(student, args):
|
||||||
|
if args.student_type == 'roberta':
|
||||||
|
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
||||||
|
elif args.student_type == 'gpt2':
|
||||||
|
student.transformer.wpe.weight.requires_grad = False
|
||||||
|
|
||||||
|
def freeze_token_type_embeddings(student, args):
|
||||||
|
if args.student_type == 'roberta':
|
||||||
|
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Training")
|
parser = argparse.ArgumentParser(description="Training")
|
||||||
|
parser.add_argument("--force", action='store_true',
|
||||||
|
help="Overwrite dump_path if it already exists.")
|
||||||
|
|
||||||
parser.add_argument("--dump_path", type=str, required=True,
|
parser.add_argument("--dump_path", type=str, required=True,
|
||||||
help="The output directory (log, checkpoints, parameters, etc.)")
|
help="The output directory (log, checkpoints, parameters, etc.)")
|
||||||
parser.add_argument("--data_file", type=str, required=True,
|
parser.add_argument("--data_file", type=str, required=True,
|
||||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
|
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
|
||||||
parser.add_argument("--token_counts", type=str, required=True,
|
|
||||||
help="The token counts in the data_file for MLM.")
|
|
||||||
parser.add_argument("--force", action='store_true',
|
|
||||||
help="Overwrite dump_path if it already exists.")
|
|
||||||
|
|
||||||
parser.add_argument("--vocab_size", default=30522, type=int,
|
parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
|
||||||
help="The vocabulary size.")
|
help="The student type (DistilBERT, RoBERTa).")
|
||||||
parser.add_argument("--max_position_embeddings", default=512, type=int,
|
parser.add_argument("--student_config", type=str, required=True,
|
||||||
help="Maximum sequence length we can model (including [CLS] and [SEP]).")
|
help="Path to the student configuration.")
|
||||||
parser.add_argument("--sinusoidal_pos_embds", action='store_false',
|
parser.add_argument("--student_pretrained_weights", default=None, type=str,
|
||||||
help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
|
|
||||||
parser.add_argument("--n_layers", default=6, type=int,
|
|
||||||
help="Number of Transformer blocks.")
|
|
||||||
parser.add_argument("--n_heads", default=12, type=int,
|
|
||||||
help="Number of heads in the self-attention module.")
|
|
||||||
parser.add_argument("--dim", default=768, type=int,
|
|
||||||
help="Dimension through the network. Must be divisible by n_heads")
|
|
||||||
parser.add_argument("--hidden_dim", default=3072, type=int,
|
|
||||||
help="Intermediate dimension in the FFN.")
|
|
||||||
parser.add_argument("--dropout", default=0.1, type=float,
|
|
||||||
help="Dropout.")
|
|
||||||
parser.add_argument("--attention_dropout", default=0.1, type=float,
|
|
||||||
help="Dropout in self-attention.")
|
|
||||||
parser.add_argument("--activation", default='gelu', type=str,
|
|
||||||
help="Activation to use in self-attention")
|
|
||||||
parser.add_argument("--tie_weights_", action='store_false',
|
|
||||||
help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
|
|
||||||
|
|
||||||
parser.add_argument("--from_pretrained_weights", default=None, type=str,
|
|
||||||
help="Load student initialization checkpoint.")
|
help="Load student initialization checkpoint.")
|
||||||
parser.add_argument("--from_pretrained_config", default=None, type=str,
|
|
||||||
help="Load student initialization architecture config.")
|
parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
|
||||||
parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
|
|
||||||
help="Teacher type (BERT, RoBERTa).")
|
help="Teacher type (BERT, RoBERTa).")
|
||||||
parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
|
parser.add_argument("--teacher_name", type=str, required=True,
|
||||||
help="The teacher model.")
|
help="The teacher model.")
|
||||||
|
|
||||||
parser.add_argument("--temperature", default=2., type=float,
|
parser.add_argument("--temperature", default=2., type=float,
|
||||||
help="Temperature for the softmax temperature.")
|
help="Temperature for the softmax temperature.")
|
||||||
parser.add_argument("--alpha_ce", default=0.5, type=float,
|
parser.add_argument("--alpha_ce", default=0.5, type=float,
|
||||||
help="Linear weight for the distillation loss. Must be >=0.")
|
help="Linear weight for the distillation loss. Must be >=0.")
|
||||||
parser.add_argument("--alpha_mlm", default=0.5, type=float,
|
parser.add_argument("--alpha_mlm", default=0.0, type=float,
|
||||||
help="Linear weight for the MLM loss. Must be >=0.")
|
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
|
||||||
|
parser.add_argument("--alpha_clm", default=0.5, type=float,
|
||||||
|
help="Linear weight for the CLM loss. Must be >=0.")
|
||||||
parser.add_argument("--alpha_mse", default=0.0, type=float,
|
parser.add_argument("--alpha_mse", default=0.0, type=float,
|
||||||
help="Linear weight of the MSE loss. Must be >=0.")
|
help="Linear weight of the MSE loss. Must be >=0.")
|
||||||
parser.add_argument("--alpha_cos", default=0.0, type=float,
|
parser.add_argument("--alpha_cos", default=0.0, type=float,
|
||||||
help="Linear weight of the cosine embedding loss. Must be >=0.")
|
help="Linear weight of the cosine embedding loss. Must be >=0.")
|
||||||
|
|
||||||
|
parser.add_argument("--mlm", action="store_true",
|
||||||
|
help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
|
||||||
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
|
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
|
||||||
help="Proportion of tokens for which we need to make a prediction.")
|
help="Proportion of tokens for which we need to make a prediction.")
|
||||||
parser.add_argument("--word_mask", default=0.8, type=float,
|
parser.add_argument("--word_mask", default=0.8, type=float,
|
||||||
@@ -95,17 +124,20 @@ def main():
|
|||||||
help="Proportion of tokens to randomly replace.")
|
help="Proportion of tokens to randomly replace.")
|
||||||
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
|
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
|
||||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
|
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
|
||||||
|
parser.add_argument("--token_counts", type=str,
|
||||||
|
help="The token counts in the data_file for MLM.")
|
||||||
|
|
||||||
parser.add_argument("--restrict_ce_to_mask", action='store_true',
|
parser.add_argument("--restrict_ce_to_mask", action='store_true',
|
||||||
help="If true, compute the distilation loss only the [MLM] prediction distribution.")
|
help="If true, compute the distilation loss only the [MLM] prediction distribution.")
|
||||||
|
parser.add_argument("--freeze_pos_embs", action="store_true",
|
||||||
|
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
|
||||||
|
parser.add_argument("--freeze_token_type_embds", action="store_true",
|
||||||
|
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
|
||||||
|
|
||||||
parser.add_argument("--n_epoch", type=int, default=3,
|
parser.add_argument("--n_epoch", type=int, default=3,
|
||||||
help="Number of pass on the whole dataset.")
|
help="Number of pass on the whole dataset.")
|
||||||
parser.add_argument("--batch_size", type=int, default=5,
|
parser.add_argument("--batch_size", type=int, default=5,
|
||||||
help="Batch size (for each process).")
|
help="Batch size (for each process).")
|
||||||
parser.add_argument("--tokens_per_batch", type=int, default=-1,
|
|
||||||
help="If specified, modify the batches so that they have approximately this number of tokens.")
|
|
||||||
parser.add_argument("--shuffle", action='store_false',
|
|
||||||
help="If true, shuffle the sequence order. Default is true.")
|
|
||||||
parser.add_argument("--group_by_size", action='store_false',
|
parser.add_argument("--group_by_size", action='store_false',
|
||||||
help="If true, group sequences that have similar length into the same batch. Default is true.")
|
help="If true, group sequences that have similar length into the same batch. Default is true.")
|
||||||
|
|
||||||
@@ -141,6 +173,7 @@ def main():
|
|||||||
parser.add_argument("--checkpoint_interval", type=int, default=4000,
|
parser.add_argument("--checkpoint_interval", type=int, default=4000,
|
||||||
help="Checkpoint interval.")
|
help="Checkpoint interval.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
sanity_checks(args)
|
||||||
|
|
||||||
|
|
||||||
## ARGS ##
|
## ARGS ##
|
||||||
@@ -164,21 +197,19 @@ def main():
|
|||||||
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
|
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
|
||||||
json.dump(vars(args), f, indent=4)
|
json.dump(vars(args), f, indent=4)
|
||||||
git_log(args.dump_path)
|
git_log(args.dump_path)
|
||||||
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
|
|
||||||
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
|
|
||||||
|
|
||||||
|
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
|
||||||
|
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
|
||||||
|
|
||||||
### TOKENIZER ###
|
### TOKENIZER ###
|
||||||
if args.teacher_type == 'bert':
|
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
|
|
||||||
elif args.teacher_type == 'roberta':
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
|
|
||||||
special_tok_ids = {}
|
special_tok_ids = {}
|
||||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||||
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
||||||
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
||||||
logger.info(f'Special tokens {special_tok_ids}')
|
logger.info(f'Special tokens {special_tok_ids}')
|
||||||
args.special_tok_ids = special_tok_ids
|
args.special_tok_ids = special_tok_ids
|
||||||
|
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||||
|
|
||||||
|
|
||||||
## DATA LOADER ##
|
## DATA LOADER ##
|
||||||
@@ -187,35 +218,34 @@ def main():
|
|||||||
data = pickle.load(fp)
|
data = pickle.load(fp)
|
||||||
|
|
||||||
|
|
||||||
assert os.path.isfile(args.token_counts)
|
if args.mlm:
|
||||||
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
|
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
|
||||||
with open(args.token_counts, 'rb') as fp:
|
with open(args.token_counts, 'rb') as fp:
|
||||||
counts = pickle.load(fp)
|
counts = pickle.load(fp)
|
||||||
assert len(counts) == args.vocab_size
|
|
||||||
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
||||||
for idx in special_tok_ids.values():
|
for idx in special_tok_ids.values():
|
||||||
token_probs[idx] = 0. # do not predict special tokens
|
token_probs[idx] = 0. # do not predict special tokens
|
||||||
token_probs = torch.from_numpy(token_probs)
|
token_probs = torch.from_numpy(token_probs)
|
||||||
|
else:
|
||||||
|
token_probs = None
|
||||||
|
|
||||||
|
|
||||||
train_dataloader = Dataset(params=args, data=data)
|
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||||
logger.info(f'Data loader created.')
|
logger.info(f'Data loader created.')
|
||||||
|
|
||||||
|
|
||||||
## STUDENT ##
|
## STUDENT ##
|
||||||
if args.from_pretrained_weights is not None:
|
logger.info(f'Loading student config from {args.student_config}')
|
||||||
assert os.path.isfile(args.from_pretrained_weights)
|
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||||
assert os.path.isfile(args.from_pretrained_config)
|
stu_architecture_config.output_hidden_states = True
|
||||||
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
|
|
||||||
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
|
if args.student_pretrained_weights is not None:
|
||||||
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
|
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
|
||||||
stu_architecture_config.output_hidden_states = True
|
student = student_model_class.from_pretrained(args.student_pretrained_weights,
|
||||||
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
|
config=stu_architecture_config)
|
||||||
config=stu_architecture_config)
|
|
||||||
else:
|
else:
|
||||||
args.vocab_size_or_config_json_file = args.vocab_size
|
student = student_model_class(stu_architecture_config)
|
||||||
stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
|
|
||||||
student = DistilBertForMaskedLM(stu_architecture_config)
|
|
||||||
|
|
||||||
|
|
||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
@@ -224,18 +254,31 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
## TEACHER ##
|
## TEACHER ##
|
||||||
if args.teacher_type == 'bert':
|
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||||
teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
|
|
||||||
elif args.teacher_type == 'roberta':
|
|
||||||
teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
|
|
||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
teacher.to(f'cuda:{args.local_rank}')
|
teacher.to(f'cuda:{args.local_rank}')
|
||||||
logger.info(f'Teacher loaded from {args.teacher_name}.')
|
logger.info(f'Teacher loaded from {args.teacher_name}.')
|
||||||
|
|
||||||
|
|
||||||
|
## FREEZING ##
|
||||||
|
if args.freeze_pos_embs:
|
||||||
|
freeze_pos_embeddings(student, args)
|
||||||
|
if args.freeze_token_type_embds:
|
||||||
|
freeze_token_type_embeddings(student, args)
|
||||||
|
|
||||||
|
|
||||||
|
## SANITY CHECKS ##
|
||||||
|
assert student.config.vocab_size == teacher.config.vocab_size
|
||||||
|
assert student.config.hidden_size == teacher.config.hidden_size
|
||||||
|
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
|
||||||
|
if args.mlm:
|
||||||
|
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||||
|
|
||||||
|
|
||||||
## DISTILLER ##
|
## DISTILLER ##
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
distiller = Distiller(params=args,
|
distiller = Distiller(params=args,
|
||||||
dataloader=train_dataloader,
|
dataset=train_lm_seq_dataset,
|
||||||
token_probs=token_probs,
|
token_probs=token_probs,
|
||||||
student=student,
|
student=student,
|
||||||
teacher=teacher)
|
teacher=teacher)
|
||||||
|
|||||||
Reference in New Issue
Block a user