Fix TF Flaubert and XLM (#9661)
* Fix Flaubert and XLM * Fix Flaubert and XLM * Apply style
This commit is contained in:
@@ -214,10 +214,13 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel):
|
|||||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
return {
|
||||||
|
"input_ids": inputs_list,
|
||||||
|
"attention_mask": attns_list,
|
||||||
|
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
langs_list = None
|
return {"input_ids": inputs_list, "attention_mask": attns_list}
|
||||||
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -536,10 +536,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
|
|||||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
return {
|
||||||
|
"input_ids": inputs_list,
|
||||||
|
"attention_mask": attns_list,
|
||||||
|
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
langs_list = None
|
return {"input_ids": inputs_list, "attention_mask": attns_list}
|
||||||
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
|
|
||||||
|
|
||||||
|
|
||||||
# Remove when XLMWithLMHead computes loss like other LM models
|
# Remove when XLMWithLMHead computes loss like other LM models
|
||||||
@@ -1045,10 +1048,16 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
Returns:
|
Returns:
|
||||||
tf.Tensor with dummy inputs
|
tf.Tensor with dummy inputs
|
||||||
"""
|
"""
|
||||||
|
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
|
||||||
|
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||||
return {
|
return {
|
||||||
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||||
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||||
|
}
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
Reference in New Issue
Block a user