Updated GLUE script to add DistilBERT. Cleaned up unused args in the utils file.
This commit is contained in:
@@ -39,7 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
|||||||
XLMConfig, XLMForSequenceClassification,
|
XLMConfig, XLMForSequenceClassification,
|
||||||
XLMTokenizer, XLNetConfig,
|
XLMTokenizer, XLNetConfig,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetTokenizer)
|
XLNetTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertTokenizer)
|
||||||
|
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
@@ -55,6 +58,7 @@ MODEL_CLASSES = {
|
|||||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||||
|
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -128,7 +132,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
@@ -218,7 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
@@ -273,11 +277,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||||
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
|
||||||
cls_token=tokenizer.cls_token,
|
|
||||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
|
||||||
sep_token=tokenizer.sep_token,
|
|
||||||
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||||
|
|||||||
@@ -390,22 +390,12 @@ class WnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||||
tokenizer, output_mode,
|
tokenizer, output_mode,
|
||||||
cls_token_at_end=False,
|
|
||||||
cls_token='[CLS]',
|
|
||||||
cls_token_segment_id=1,
|
|
||||||
sep_token='[SEP]',
|
|
||||||
sep_token_extra=False,
|
|
||||||
pad_on_left=False,
|
pad_on_left=False,
|
||||||
pad_token=0,
|
pad_token=0,
|
||||||
pad_token_segment_id=0,
|
pad_token_segment_id=0,
|
||||||
sequence_a_segment_id=0,
|
|
||||||
sequence_b_segment_id=1,
|
|
||||||
mask_padding_with_zero=True):
|
mask_padding_with_zero=True):
|
||||||
""" Loads a data file into a list of `InputBatch`s
|
"""
|
||||||
`cls_token_at_end` define the location of the CLS token:
|
Loads a data file into a list of `InputBatch`s
|
||||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
|
||||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
|
||||||
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
label_map = {label : i for i, label in enumerate(label_list)}
|
label_map = {label : i for i, label in enumerate(label_list)}
|
||||||
|
|||||||
Reference in New Issue
Block a user