[RoBERTa] Update run_glue for RoBERTa
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
|
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
@@ -33,6 +33,9 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||||
BertForSequenceClassification, BertTokenizer,
|
BertForSequenceClassification, BertTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaTokenizer,
|
||||||
XLMConfig, XLMForSequenceClassification,
|
XLMConfig, XLMForSequenceClassification,
|
||||||
XLMTokenizer, XLNetConfig,
|
XLMTokenizer, XLNetConfig,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
@@ -45,12 +48,13 @@ from utils_glue import (compute_metrics, convert_examples_to_features,
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||||
|
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -214,7 +218,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
@@ -268,8 +272,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||||
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
||||||
cls_token=tokenizer.cls_token,
|
cls_token=tokenizer.cls_token,
|
||||||
sep_token=tokenizer.sep_token,
|
|
||||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||||
|
sep_token=tokenizer.sep_token,
|
||||||
|
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
|
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
|
|||||||
@@ -390,10 +390,16 @@ class WnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||||
tokenizer, output_mode,
|
tokenizer, output_mode,
|
||||||
cls_token_at_end=False, pad_on_left=False,
|
cls_token_at_end=False,
|
||||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
cls_token='[CLS]',
|
||||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
cls_token_segment_id=1,
|
||||||
cls_token_segment_id=1, pad_token_segment_id=0,
|
sep_token='[SEP]',
|
||||||
|
sep_token_extra=False,
|
||||||
|
pad_on_left=False,
|
||||||
|
pad_token=0,
|
||||||
|
pad_token_segment_id=0,
|
||||||
|
sequence_a_segment_id=0,
|
||||||
|
sequence_b_segment_id=1,
|
||||||
mask_padding_with_zero=True):
|
mask_padding_with_zero=True):
|
||||||
""" Loads a data file into a list of `InputBatch`s
|
""" Loads a data file into a list of `InputBatch`s
|
||||||
`cls_token_at_end` defines the location of the CLS token:
|
`cls_token_at_end` defines the location of the CLS token:
|
||||||
@@ -442,6 +448,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
|||||||
# used as the "sentence vector". Note that this only makes sense because
|
# used as the "sentence vector". Note that this only makes sense because
|
||||||
# the entire model is fine-tuned.
|
# the entire model is fine-tuned.
|
||||||
tokens = tokens_a + [sep_token]
|
tokens = tokens_a + [sep_token]
|
||||||
|
if sep_token_extra:
|
||||||
|
# roberta uses an extra separator b/w pairs of sentences
|
||||||
|
tokens += [sep_token]
|
||||||
segment_ids = [sequence_a_segment_id] * len(tokens)
|
segment_ids = [sequence_a_segment_id] * len(tokens)
|
||||||
|
|
||||||
if tokens_b:
|
if tokens_b:
|
||||||
|
|||||||
Reference in New Issue
Block a user