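Change `is_regression` to the unified `num_labels` API (`num_labels == 1` selects regression with MSELoss) and rename the `target` argument to `labels`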
This commit is contained in:
@@ -591,3 +591,15 @@ output_modes = {
|
|||||||
"rte": "classification",
|
"rte": "classification",
|
||||||
"wnli": "classification",
|
"wnli": "classification",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
|
"cola": 2,
|
||||||
|
"mnli": 3,
|
||||||
|
"mrpc": 2,
|
||||||
|
"sst-2": 2,
|
||||||
|
"sts-b": 1,
|
||||||
|
"qqp": 2,
|
||||||
|
"qnli": 2,
|
||||||
|
"rte": 2,
|
||||||
|
"wnli": 2,
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,16 +28,16 @@ from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
|
|||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
load_tf_weights_in_xlnet)
|
load_tf_weights_in_xlnet)
|
||||||
|
|
||||||
GLUE_TASKS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
"cola": "classification",
|
"cola": 2,
|
||||||
"mnli": "classification",
|
"mnli": 3,
|
||||||
"mrpc": "classification",
|
"mrpc": 2,
|
||||||
"sst-2": "classification",
|
"sst-2": 2,
|
||||||
"sts-b": "regression",
|
"sts-b": 1,
|
||||||
"qqp": "classification",
|
"qqp": 2,
|
||||||
"qnli": "classification",
|
"qnli": 2,
|
||||||
"rte": "classification",
|
"rte": 2,
|
||||||
"wnli": "classification",
|
"wnli": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -46,9 +46,9 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
|||||||
config = XLNetConfig.from_json_file(bert_config_file)
|
config = XLNetConfig.from_json_file(bert_config_file)
|
||||||
|
|
||||||
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
|
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
|
||||||
if finetuning_task in GLUE_TASKS:
|
if finetuning_task in GLUE_TASKS_NUM_LABELS:
|
||||||
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
|
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
|
||||||
model = XLNetForSequenceClassification(config, is_regression=bool(GLUE_TASKS[finetuning_task] == "regression"))
|
model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
|
||||||
elif 'squad' in finetuning_task:
|
elif 'squad' in finetuning_task:
|
||||||
model = XLNetForQuestionAnswering(config)
|
model = XLNetForQuestionAnswering(config)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from io import open
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||||
|
|
||||||
@@ -1196,6 +1196,11 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if self.num_labels == 1:
|
||||||
|
# We are doing regression
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||||
|
else:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -1175,7 +1175,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
target=None, output_all_encoded_layers=True, head_mask=None):
|
labels=None, output_all_encoded_layers=True, head_mask=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||||
@@ -1212,11 +1212,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
logits = self.lm_loss(output)
|
logits = self.lm_loss(output)
|
||||||
|
|
||||||
if target is not None:
|
if labels is not None:
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
||||||
target.view(-1))
|
labels.view(-1))
|
||||||
return loss, new_mems
|
return loss, new_mems
|
||||||
|
|
||||||
# if self.output_attentions:
|
# if self.output_attentions:
|
||||||
@@ -1305,13 +1305,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
Outputs: Tuple of (logits or loss, mems)
|
Outputs: Tuple of (logits or loss, mems)
|
||||||
`logits or loss`:
|
`logits or loss`:
|
||||||
if target is None:
|
if labels is None:
|
||||||
Token logits with shape [batch_size, sequence_length]
|
Token logits with shape [batch_size, sequence_length]
|
||||||
else:
|
else:
|
||||||
CrossEntropy loss with the targets
|
CrossEntropy loss with the targets
|
||||||
`new_mems`: list (num layers) of updated mem states at the entry of each layer
|
`new_mems`: list (num layers) of updated mem states at the entry of each layer
|
||||||
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
|
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
|
||||||
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
|
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -1328,13 +1328,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
||||||
is_regression=False, output_attentions=False, keep_multihead_output=False):
|
output_attentions=False, keep_multihead_output=False):
|
||||||
super(XLNetForSequenceClassification, self).__init__(config)
|
super(XLNetForSequenceClassification, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
self.same_length = config.same_length
|
self.same_length = config.same_length
|
||||||
self.summary_type = summary_type
|
self.summary_type = summary_type
|
||||||
self.is_regression = is_regression
|
self.num_labels = num_labels
|
||||||
|
|
||||||
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
keep_multihead_output=keep_multihead_output)
|
||||||
@@ -1342,12 +1342,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
|
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
|
||||||
use_proj=use_proj, output_attentions=output_attentions,
|
use_proj=use_proj, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
keep_multihead_output=keep_multihead_output)
|
||||||
self.logits_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
|
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
target=None, output_all_encoded_layers=True, head_mask=None):
|
labels=None, output_all_encoded_layers=True, head_mask=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||||
@@ -1382,13 +1382,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
output = self.sequence_summary(output)
|
output = self.sequence_summary(output)
|
||||||
logits = self.logits_proj(output)
|
logits = self.logits_proj(output)
|
||||||
|
|
||||||
if target is not None:
|
if labels is not None:
|
||||||
if self.is_regression:
|
if self.num_labels == 1:
|
||||||
|
# We are doing regression
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
loss = loss_fct(logits.view(-1), target.view(-1))
|
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, new_mems
|
return loss, new_mems
|
||||||
|
|
||||||
# if self.output_attentions:
|
# if self.output_attentions:
|
||||||
|
|||||||
Reference in New Issue
Block a user