From f1b018740c9355f0bcf0093fc993724eaa737445 Mon Sep 17 00:00:00 2001 From: Shijie Wu Date: Fri, 23 Aug 2019 20:33:01 -0400 Subject: [PATCH] Add use_lang_emb to config --- pytorch_transformers/modeling_xlm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 19800da2ed..10be972ea5 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -114,6 +114,7 @@ class XLMConfig(PretrainedConfig): causal=False, asm=False, n_langs=1, + use_lang_emb=True, max_position_embeddings=512, embed_init_std=2048 ** -0.5, layer_norm_eps=1e-12, @@ -157,6 +158,7 @@ class XLMConfig(PretrainedConfig): self.causal = causal self.asm = asm self.n_langs = n_langs + self.use_lang_emb = use_lang_emb self.layer_norm_eps = layer_norm_eps self.bos_index = bos_index self.eos_index = eos_index @@ -488,7 +490,7 @@ class XLMModel(XLMPreTrainedModel): """ ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output', - 'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', + 'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value'] @@ -507,6 +509,7 @@ class XLMModel(XLMPreTrainedModel): # dictionary / languages self.n_langs = config.n_langs + self.use_lang_emb = config.use_lang_emb self.n_words = config.n_words self.eos_index = config.eos_index self.pad_index = config.pad_index @@ -529,7 +532,7 @@ class XLMModel(XLMPreTrainedModel): self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim) if config.sinusoidal_embeddings: create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) - if config.n_langs > 1: + if config.n_langs > 1 and config.use_lang_emb: self.lang_embeddings = nn.Embedding(self.n_langs, self.dim) self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index) self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) @@ -628,7 +631,7 @@ class XLMModel(XLMPreTrainedModel): # embeddings tensor = self.embeddings(input_ids) tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor) - if langs is not None: + if langs is not None and self.use_lang_emb: tensor = tensor + self.lang_embeddings(langs) if token_type_ids is not None: tensor = tensor + self.embeddings(token_type_ids)