Merge pull request #3191 from patrickvonplaten/add_integration_tests_lm_generate_torch_tf

Add integration tests lm generate torch tf
This commit is contained in:
Patrick von Platen
2020-03-10 11:29:17 +01:00
committed by GitHub
13 changed files with 1075 additions and 189 deletions

View File

@@ -408,7 +408,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids)
if langs is not None and self.use_lang_emb:
if langs is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)