Debug LukeForMaskedLM (#17499)
* add a test for a word only input
* make LukeForMaskedLM work without entity inputs
* update test
* add LukeForMaskedLM to MODEL_FOR_MASKED_LM_MAPPING_NAMES
* restore pyproject.toml
* empty line at the end of pyproject.toml
This commit is contained in:
@@ -377,6 +377,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("ibert", "IBertForMaskedLM"),
|
("ibert", "IBertForMaskedLM"),
|
||||||
("layoutlm", "LayoutLMForMaskedLM"),
|
("layoutlm", "LayoutLMForMaskedLM"),
|
||||||
("longformer", "LongformerForMaskedLM"),
|
("longformer", "LongformerForMaskedLM"),
|
||||||
|
("luke", "LukeForMaskedLM"),
|
||||||
("mbart", "MBartForConditionalGeneration"),
|
("mbart", "MBartForConditionalGeneration"),
|
||||||
("megatron-bert", "MegatronBertForMaskedLM"),
|
("megatron-bert", "MegatronBertForMaskedLM"),
|
||||||
("mobilebert", "MobileBertForMaskedLM"),
|
("mobilebert", "MobileBertForMaskedLM"),
|
||||||
|
|||||||
@@ -1229,13 +1229,15 @@ class LukeForMaskedLM(LukePreTrainedModel):
|
|||||||
loss = mlm_loss
|
loss = mlm_loss
|
||||||
|
|
||||||
mep_loss = None
|
mep_loss = None
|
||||||
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
|
entity_logits = None
|
||||||
if entity_labels is not None:
|
if outputs.entity_last_hidden_state is not None:
|
||||||
mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
|
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
|
||||||
if loss is None:
|
if entity_labels is not None:
|
||||||
loss = mep_loss
|
mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
|
||||||
else:
|
if loss is None:
|
||||||
loss = loss + mep_loss
|
loss = mep_loss
|
||||||
|
else:
|
||||||
|
loss = loss + mep_loss
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
|
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
|
||||||
|
|||||||
@@ -270,9 +270,12 @@ class LukeModelTester:
|
|||||||
entity_labels=entity_labels,
|
entity_labels=entity_labels,
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
self.parent.assertEqual(
|
if entity_ids is not None:
|
||||||
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
|
self.parent.assertEqual(
|
||||||
)
|
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.parent.assertIsNone(result.entity_logits)
|
||||||
|
|
||||||
def create_and_check_for_entity_classification(
|
def create_and_check_for_entity_classification(
|
||||||
self,
|
self,
|
||||||
@@ -488,6 +491,11 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_lm_with_word_only(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
|
||||||
|
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_entity_classification(self):
|
def test_for_entity_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user