Add use_lang_emb to config
This commit is contained in:
@@ -114,6 +114,7 @@ class XLMConfig(PretrainedConfig):
|
|||||||
causal=False,
|
causal=False,
|
||||||
asm=False,
|
asm=False,
|
||||||
n_langs=1,
|
n_langs=1,
|
||||||
|
use_lang_emb=True,
|
||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
embed_init_std=2048 ** -0.5,
|
embed_init_std=2048 ** -0.5,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
@@ -157,6 +158,7 @@ class XLMConfig(PretrainedConfig):
|
|||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.asm = asm
|
self.asm = asm
|
||||||
self.n_langs = n_langs
|
self.n_langs = n_langs
|
||||||
|
self.use_lang_emb = use_lang_emb
|
||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.bos_index = bos_index
|
self.bos_index = bos_index
|
||||||
self.eos_index = eos_index
|
self.eos_index = eos_index
|
||||||
@@ -488,7 +490,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
|
ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
|
||||||
'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
|
'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
|
||||||
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
|
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
|
||||||
'asm_cutoffs', 'asm_div_value']
|
'asm_cutoffs', 'asm_div_value']
|
||||||
|
|
||||||
@@ -507,6 +509,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
# dictionary / languages
|
# dictionary / languages
|
||||||
self.n_langs = config.n_langs
|
self.n_langs = config.n_langs
|
||||||
|
self.use_lang_emb = config.use_lang_emb
|
||||||
self.n_words = config.n_words
|
self.n_words = config.n_words
|
||||||
self.eos_index = config.eos_index
|
self.eos_index = config.eos_index
|
||||||
self.pad_index = config.pad_index
|
self.pad_index = config.pad_index
|
||||||
@@ -529,7 +532,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
|
||||||
if config.sinusoidal_embeddings:
|
if config.sinusoidal_embeddings:
|
||||||
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
|
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
|
||||||
if config.n_langs > 1:
|
if config.n_langs > 1 and config.use_lang_emb:
|
||||||
self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
|
self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
|
||||||
self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
|
self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
|
||||||
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
|
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
|
||||||
@@ -628,7 +631,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
# embeddings
|
# embeddings
|
||||||
tensor = self.embeddings(input_ids)
|
tensor = self.embeddings(input_ids)
|
||||||
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
|
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
|
||||||
if langs is not None:
|
if langs is not None and self.use_lang_emb:
|
||||||
tensor = tensor + self.lang_embeddings(langs)
|
tensor = tensor + self.lang_embeddings(langs)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
tensor = tensor + self.embeddings(token_type_ids)
|
tensor = tensor + self.embeddings(token_type_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user