From 4620caa864997a361a969266ed57943ad1b3f11e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 9 Mar 2020 11:18:54 +0100 Subject: [PATCH] fix if use lang embeddings in tf xlm --- src/transformers/modeling_tf_xlm.py | 2 +- tests/test_modeling_tf_xlm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 6e94a7206e..407f83d05d 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -408,7 +408,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): inputs_embeds = self.embeddings(input_ids) tensor = inputs_embeds + self.position_embeddings(position_ids) - if langs is not None and self.use_lang_emb: + if langs is not None and self.use_lang_emb and self.n_langs > 1: tensor = tensor + self.lang_embeddings(langs) if token_type_ids is not None: tensor = tensor + self.embeddings(token_type_ids) diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index c2557cff38..d98fc44131 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -342,4 +342,4 @@ class TFXLMModelLanguageGenerationTest(unittest.TestCase): ] # the president the president the president the president the president the president the president the president the president the president # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference output_ids = model.generate(input_ids) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False) + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids, do_sample=False)