From 4d1ce3968346737c437245f86213beff16bd6899 Mon Sep 17 00:00:00 2001 From: Ryokan RI Date: Wed, 1 Jun 2022 23:03:06 +0900 Subject: [PATCH] Debug LukeForMaskedLM (#17499) * add a test for a word only input * make LukeForMaskedLM work without entity inputs * update test * add LukeForMaskedLM to MODEL_FOR_MASKED_LM_MAPPING_NAMES * restore pyproject.toml * empty line at the end of pyproject.toml --- src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/luke/modeling_luke.py | 16 +++++++++------- tests/models/luke/test_modeling_luke.py | 14 +++++++++++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index be7dc5bc9e..61787c3d60 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -377,6 +377,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( ("ibert", "IBertForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"), ("longformer", "LongformerForMaskedLM"), + ("luke", "LukeForMaskedLM"), ("mbart", "MBartForConditionalGeneration"), ("megatron-bert", "MegatronBertForMaskedLM"), ("mobilebert", "MobileBertForMaskedLM"), diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index f7c36ff93d..4c2491aee7 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -1229,13 +1229,15 @@ class LukeForMaskedLM(LukePreTrainedModel): loss = mlm_loss mep_loss = None - entity_logits = self.entity_predictions(outputs.entity_last_hidden_state) - if entity_labels is not None: - mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1)) - if loss is None: - loss = mep_loss - else: - loss = loss + mep_loss + entity_logits = None + if outputs.entity_last_hidden_state is not None: + entity_logits = self.entity_predictions(outputs.entity_last_hidden_state) + if entity_labels is not None: + mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1)) + if loss is None: + loss = mep_loss + else: + loss = loss + mep_loss if not return_dict: output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions) diff --git a/tests/models/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py index b6c9ef89ff..264b7f8955 100644 --- a/tests/models/luke/test_modeling_luke.py +++ b/tests/models/luke/test_modeling_luke.py @@ -270,9 +270,12 @@ class LukeModelTester: entity_labels=entity_labels, ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - self.parent.assertEqual( - result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size) - ) + if entity_ids is not None: + self.parent.assertEqual( + result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size) + ) + else: + self.parent.assertIsNone(result.entity_logits) def create_and_check_for_entity_classification( self, @@ -488,6 +491,11 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_masked_lm_with_word_only(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:]))) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_entity_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)