Unify label args (#4722)
* Deprecate masked_lm_labels argument * Apply to all models * Better error message
This commit is contained in:
@@ -71,7 +71,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"lm_labels": decoder_token_labels,
|
||||
"masked_lm_labels": decoder_token_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model(
|
||||
@@ -224,7 +224,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
def check_loss_output(self, loss):
|
||||
self.assertEqual(loss.size(), ())
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_mlm_labels(
|
||||
def create_and_check_bert_encoder_decoder_model_labels(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
@@ -233,7 +233,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
masked_lm_labels,
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
@@ -245,7 +245,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
masked_lm_labels=masked_lm_labels,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
mlm_loss = outputs_encoder_decoder[0]
|
||||
@@ -316,9 +316,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_save_and_load_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_mlm_labels(self):
|
||||
def test_bert_encoder_decoder_model_labels(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_mlm_labels(**input_ids_dict)
|
||||
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_lm_labels(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
|
||||
Reference in New Issue
Block a user