From 75635072e11afbf716011464d30869efbe226886 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 6 Sep 2019 17:23:47 -0400 Subject: [PATCH] Updated GLUE script to add DistilBERT. Cleaned up unused args in the utils file. --- examples/run_glue.py | 15 +++++++-------- examples/utils_glue.py | 14 ++------------ 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index e20f6d84c4..e8c2edc833 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -39,7 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig, XLMConfig, XLMForSequenceClassification, XLMTokenizer, XLNetConfig, XLNetForSequenceClassification, - XLNetTokenizer) + XLNetTokenizer, + DistilBertConfig, + DistilBertForSequenceClassification, + DistilBertTokenizer) from pytorch_transformers import AdamW, WarmupLinearSchedule @@ -55,6 +58,7 @@ MODEL_CLASSES = { 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), + 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer) } @@ -128,7 +132,7 @@ def train(args, train_dataset, model, tokenizer): batch = tuple(t.to(args.device) for t in batch) inputs = {'input_ids': batch[0], 'attention_mask': batch[1], - 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids + 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids 'labels': batch[3]} outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) @@ -218,7 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""): with torch.no_grad(): inputs = {'input_ids': batch[0], 'attention_mask': batch[1], - 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids + 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids 'labels': batch[3]} outputs = model(**inputs) tmp_eval_loss, logits = outputs[:2] @@ -273,11 +277,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): label_list[1], label_list[2] = label_list[2], label_list[1] examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, - cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end - cls_token=tokenizer.cls_token, - cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0, - sep_token=tokenizer.sep_token, - sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, diff --git a/examples/utils_glue.py b/examples/utils_glue.py index 96023e149c..8d0ed737d7 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -390,22 +390,12 @@ class WnliProcessor(DataProcessor): def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode, - cls_token_at_end=False, - cls_token='[CLS]', - cls_token_segment_id=1, - sep_token='[SEP]', - sep_token_extra=False, pad_on_left=False, pad_token=0, pad_token_segment_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=1, mask_padding_with_zero=True): - """ Loads a data file into a list of `InputBatch`s - `cls_token_at_end` define the location of the CLS token: - - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] - - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] - `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) + """ + Loads a data file into a list of `InputBatch`s """ label_map = {label : i for i, label in enumerate(label_list)}