From fa876aee2adf525b597495c10ad9c96896953dbd Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 19 Jan 2021 18:02:57 +0100 Subject: [PATCH] Fix TF Flaubert and XLM (#9661) * Fix Flaubert and XLM * Fix Flaubert and XLM * Apply style --- .../models/flaubert/modeling_tf_flaubert.py | 9 +++++--- .../models/xlm/modeling_tf_xlm.py | 23 +++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index 28ebba7daa..f24dfa7473 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -214,10 +214,13 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel): inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) if self.config.use_lang_emb and self.config.n_langs > 1: - langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + return { + "input_ids": inputs_list, + "attention_mask": attns_list, + "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), + } else: - langs_list = None - return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list} + return {"input_ids": inputs_list, "attention_mask": attns_list} @add_start_docstrings( diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index 31d8b4fd49..8cd3c7ef48 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -536,10 +536,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel): inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) if self.config.use_lang_emb and self.config.n_langs > 1: - langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + return { + "input_ids": inputs_list, + "attention_mask": attns_list, + "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), + } else: - langs_list = None - return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list} + return {"input_ids": inputs_list, "attention_mask": attns_list} # Remove when XLMWithLMHead computes loss like other LM models @@ -1045,10 +1048,16 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): Returns: tf.Tensor with dummy inputs """ - return { - "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS), - "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS), - } + # Sometimes XLM has language embeddings so don't forget to build them as well if needed + if self.config.use_lang_emb and self.config.n_langs > 1: + return { + "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS), + "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS), + } + else: + return { + "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS), + } @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_code_sample_docstrings(