Adding fine-tuning models to LUKE (#18353)
* add LUKE models for downstream tasks * add new LUKE models to docs * fix typos * remove commented lines * exclude None items from tuple return values
This commit is contained in:
@@ -30,6 +30,10 @@ if is_torch_available():
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeForMultipleChoice,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeModel,
|
||||
LukeTokenizer,
|
||||
)
|
||||
@@ -66,6 +70,8 @@ class LukeModelTester:
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
num_entity_classification_labels=9,
|
||||
num_entity_pair_classification_labels=6,
|
||||
num_entity_span_classification_labels=4,
|
||||
@@ -99,6 +105,8 @@ class LukeModelTester:
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.num_entity_classification_labels = num_entity_classification_labels
|
||||
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
|
||||
self.num_entity_span_classification_labels = num_entity_span_classification_labels
|
||||
@@ -139,7 +147,8 @@ class LukeModelTester:
|
||||
)
|
||||
|
||||
sequence_labels = None
|
||||
labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
entity_labels = None
|
||||
entity_classification_labels = None
|
||||
entity_pair_classification_labels = None
|
||||
@@ -147,7 +156,9 @@ class LukeModelTester:
|
||||
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
|
||||
|
||||
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
|
||||
@@ -170,7 +181,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -207,7 +219,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -247,7 +260,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -266,7 +280,7 @@ class LukeModelTester:
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=labels,
|
||||
labels=token_labels,
|
||||
entity_labels=entity_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
@@ -288,7 +302,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -322,7 +337,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -356,7 +372,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -386,6 +403,156 @@ class LukeModelTester:
|
||||
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
|
||||
)
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
model = LukeForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LukeForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LukeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=token_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = LukeForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_entity_token_type_ids = (
|
||||
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
)
|
||||
multiple_choice_entity_attention_mask = (
|
||||
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
)
|
||||
multiple_choice_entity_position_ids = (
|
||||
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
|
||||
)
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_attention_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
entity_ids=multiple_choice_entity_ids,
|
||||
entity_attention_mask=multiple_choice_entity_attention_mask,
|
||||
entity_token_type_ids=multiple_choice_entity_token_type_ids,
|
||||
entity_position_ids=multiple_choice_entity_position_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -398,7 +565,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_head_masking = True
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
|
||||
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
|
||||
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
if model_class == LukeForMultipleChoice:
|
||||
entity_inputs_dict = {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if v.ndim == 2
|
||||
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
|
||||
for k, v in entity_inputs_dict.items()
|
||||
}
|
||||
inputs_dict.update(entity_inputs_dict)
|
||||
|
||||
if model_class == LukeForEntitySpanClassification:
|
||||
inputs_dict["entity_start_positions"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
|
||||
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
if return_labels:
|
||||
if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
|
||||
if model_class in (
|
||||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForSequenceClassification,
|
||||
LukeForMultipleChoice,
|
||||
):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
elif model_class == LukeForTokenClassification:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
elif model_class == LukeForMaskedLM:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length),
|
||||
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_entity_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user