Unify label args (#4722)

* Deprecate masked_lm_labels argument

* Apply to all models

* Better error message
This commit is contained in:
Sylvain Gugger
2020-06-03 09:36:26 -04:00
committed by GitHub
parent 3e5928c57d
commit 1b5820a565
14 changed files with 223 additions and 93 deletions

View File

@@ -71,7 +71,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"masked_lm_labels": decoder_token_labels,
"labels": decoder_token_labels,
}
def create_and_check_bert_encoder_decoder_model(
@@ -224,7 +224,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
def check_loss_output(self, loss):
self.assertEqual(loss.size(), ())
def create_and_check_bert_encoder_decoder_model_mlm_labels(
def create_and_check_bert_encoder_decoder_model_labels(
self,
config,
input_ids,
@@ -233,7 +233,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
decoder_config,
decoder_input_ids,
decoder_attention_mask,
masked_lm_labels,
labels,
**kwargs
):
encoder_model = BertModel(config)
@@ -245,7 +245,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
masked_lm_labels=masked_lm_labels,
labels=labels,
)
mlm_loss = outputs_encoder_decoder[0]
@@ -316,9 +316,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load_encoder_decoder_model(**input_ids_dict)
def test_bert_encoder_decoder_model_mlm_labels(self):
def test_bert_encoder_decoder_model_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_mlm_labels(**input_ids_dict)
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()