Debug LukeForMaskedLM (#17499)
* add a test for a word only input * make LukeForMaskedLM work without entity inputs * update test * add LukeForMaskedLM to MODEL_FOR_MASKED_LM_MAPPING_NAMES * restore pyproject.toml * empty line at the end of pyproject.toml
This commit is contained in:
@@ -270,9 +270,12 @@ class LukeModelTester:
|
||||
entity_labels=entity_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(
|
||||
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
|
||||
)
|
||||
if entity_ids is not None:
|
||||
self.parent.assertEqual(
|
||||
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
|
||||
)
|
||||
else:
|
||||
self.parent.assertIsNone(result.entity_logits)
|
||||
|
||||
def create_and_check_for_entity_classification(
|
||||
self,
|
||||
@@ -488,6 +491,11 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm_with_word_only(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_entity_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user