Merge pull request #1870 from alexzubiaga/xlnet-for-token-classification
XLNet for Token classification
This commit is contained in:
@@ -30,6 +30,7 @@ if is_tf_available():
|
||||
|
||||
from transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
else:
|
||||
@@ -42,6 +43,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple) if is_tf_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@@ -258,6 +260,26 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_for_token_classification(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
config.num_labels = input_ids_1.shape[1]
|
||||
model = TFXLNetForTokenClassification(config)
|
||||
inputs = {'input_ids': input_ids_1,
|
||||
'attention_mask': input_mask,
|
||||
# 'token_type_ids': token_type_ids
|
||||
}
|
||||
logits, mems_1 = model(inputs)
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, config.num_labels])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
@@ -289,6 +311,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_xlnet_qa(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
@@ -28,7 +28,8 @@ from transformers import is_torch_available
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification,
|
||||
XLNetForTokenClassification, XLNetForQuestionAnswering)
|
||||
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
@@ -38,7 +39,7 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel,
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@@ -107,10 +108,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
sequence_labels = None
|
||||
lm_labels = None
|
||||
is_impossible_labels = None
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
config = XLNetConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
@@ -129,14 +132,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
num_labels=self.type_sequence_label_size)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetModel(config)
|
||||
model.eval()
|
||||
|
||||
@@ -164,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
@@ -204,7 +207,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
model.eval()
|
||||
|
||||
@@ -261,8 +264,40 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForTokenClassification(config)
|
||||
model.eval()
|
||||
|
||||
logits, mems_1 = model(input_ids_1)
|
||||
loss, logits, mems_1 = model(input_ids_1, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"mems_1": mems_1,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.type_sequence_label_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels,
|
||||
sequence_labels, is_impossible_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForSequenceClassification(config)
|
||||
model.eval()
|
||||
|
||||
@@ -289,7 +324,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels,
|
||||
sequence_labels, is_impossible_labels) = config_and_inputs
|
||||
sequence_labels, is_impossible_labels, token_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -316,6 +351,11 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_token_classif(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_qa(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user