From 912fdff899cf0fd674ed357e46a0209311aefad2 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 12 Aug 2019 13:49:50 -0400 Subject: [PATCH] [RoBERTa] Update `run_glue` for RoBERTa --- examples/run_glue.py | 13 +++++++++---- examples/utils_glue.py | 17 +++++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index a939ea373b..f6cd73ed0b 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet).""" +""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa).""" from __future__ import absolute_import, division, print_function @@ -33,6 +33,9 @@ from tqdm import tqdm, trange from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer, + RobertaConfig, + RobertaForSequenceClassification, + RobertaTokenizer, XLMConfig, XLMForSequenceClassification, XLMTokenizer, XLNetConfig, XLNetForSequenceClassification, @@ -45,12 +48,13 @@ from utils_glue import (compute_metrics, convert_examples_to_features, logger = logging.getLogger(__name__) -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ()) MODEL_CLASSES = { 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), + 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), } @@ -214,7 +218,7 @@ def evaluate(args, model, tokenizer, prefix=""): with torch.no_grad(): inputs = {'input_ids': batch[0], 'attention_mask': batch[1], - 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids + 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids 'labels': batch[3]} outputs = model(**inputs) tmp_eval_loss, logits = outputs[:2] @@ -268,8 +272,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end cls_token=tokenizer.cls_token, - sep_token=tokenizer.sep_token, cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0, + sep_token=tokenizer.sep_token, + sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0) if args.local_rank in [-1, 0]: diff --git a/examples/utils_glue.py b/examples/utils_glue.py index bba9a901a8..c955e4d0ce 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -390,10 +390,16 @@ class WnliProcessor(DataProcessor): def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode, - cls_token_at_end=False, pad_on_left=False, - cls_token='[CLS]', sep_token='[SEP]', pad_token=0, - sequence_a_segment_id=0, sequence_b_segment_id=1, - cls_token_segment_id=1, pad_token_segment_id=0, + cls_token_at_end=False, + cls_token='[CLS]', + cls_token_segment_id=1, + sep_token='[SEP]', + sep_token_extra=False, + pad_on_left=False, + pad_token=0, + pad_token_segment_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=1, mask_padding_with_zero=True): """ Loads a data file into a list of `InputBatch`s `cls_token_at_end` define the location of the CLS token: @@ -442,6 +448,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length, # used as as the "sentence vector". Note that this only makes sense because # the entire model is fine-tuned. tokens = tokens_a + [sep_token] + if sep_token_extra: + # roberta uses an extra separator b/w pairs of sentences + tokens += [sep_token] segment_ids = [sequence_a_segment_id] * len(tokens) if tokens_b: