Adding fine-tuning models to LUKE (#18353)

* add LUKE models for downstream tasks

* add new LUKE models to docs

* fix typos

* remove commented lines

* exclude None items from tuple return values
This commit is contained in:
Ikuya Yamada
2022-08-02 00:09:47 +09:00
committed by GitHub
parent 7b9e995b70
commit 62098b9348
8 changed files with 922 additions and 38 deletions

View File

@@ -30,6 +30,10 @@ if is_torch_available():
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukeTokenizer,
)
@@ -66,6 +70,8 @@ class LukeModelTester:
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
num_entity_classification_labels=9,
num_entity_pair_classification_labels=6,
num_entity_span_classification_labels=4,
@@ -99,6 +105,8 @@ class LukeModelTester:
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.num_entity_classification_labels = num_entity_classification_labels
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
self.num_entity_span_classification_labels = num_entity_span_classification_labels
@@ -139,7 +147,8 @@ class LukeModelTester:
)
sequence_labels = None
labels = None
token_labels = None
choice_labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
@@ -147,7 +156,9 @@ class LukeModelTester:
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
@@ -170,7 +181,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -207,7 +219,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -247,7 +260,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -266,7 +280,7 @@ class LukeModelTester:
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=labels,
labels=token_labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
@@ -288,7 +302,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -322,7 +337,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -356,7 +372,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -386,6 +403,156 @@ class LukeModelTester:
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
)
def create_and_check_for_question_answering(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
model = LukeForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_for_sequence_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=token_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_choices = self.num_choices
model = LukeForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_token_type_ids = (
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_attention_mask = (
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_position_ids = (
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
)
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_attention_mask,
token_type_ids=multiple_choice_token_type_ids,
entity_ids=multiple_choice_entity_ids,
entity_attention_mask=multiple_choice_entity_attention_mask,
entity_token_type_ids=multiple_choice_entity_token_type_ids,
entity_position_ids=multiple_choice_entity_position_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -398,7 +565,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeForMultipleChoice,
)
if is_torch_available()
else ()
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
test_head_masking = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if model_class == LukeForMultipleChoice:
entity_inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if v.ndim == 2
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
for k, v in entity_inputs_dict.items()
}
inputs_dict.update(entity_inputs_dict)
if model_class == LukeForEntitySpanClassification:
inputs_dict["entity_start_positions"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
)
if return_labels:
if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
if model_class in (
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForSequenceClassification,
LukeForMultipleChoice,
):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForTokenClassification:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)