Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)
* first commit * work in progress * make language generation task pass * update to working version for LM * delete print * remove dead code * make style
This commit is contained in:
committed by
GitHub
parent
efdb46b6e2
commit
292186a3e7
@@ -357,6 +357,7 @@ if is_tf_available():
|
|||||||
TFTransfoXLModel,
|
TFTransfoXLModel,
|
||||||
TFTransfoXLLMHeadModel,
|
TFTransfoXLLMHeadModel,
|
||||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
TFAdaptiveEmbedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .modeling_tf_xlnet import (
|
from .modeling_tf_xlnet import (
|
||||||
|
|||||||
@@ -733,6 +733,25 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class TFTransfoXLLMHead(tf.keras.layers.Layer):
|
||||||
|
def __init__(self, config, input_embeddings, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
self.input_embeddings = input_embeddings
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
|
def call(self, hidden_states):
|
||||||
|
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||||
|
hidden_states = hidden_states + self.bias
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""The Transformer-XL Model with a language modeling head on top
|
"""The Transformer-XL Model with a language modeling head on top
|
||||||
(adaptive softmax with weights tied to the adaptive input embeddings)""",
|
(adaptive softmax with weights tied to the adaptive input embeddings)""",
|
||||||
@@ -743,14 +762,20 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
|
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
|
||||||
self.sample_softmax = config.sample_softmax
|
self.sample_softmax = config.sample_softmax
|
||||||
# use sampled softmax
|
assert (
|
||||||
if config.sample_softmax > 0:
|
self.sample_softmax <= 0
|
||||||
raise NotImplementedError
|
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
|
||||||
# use adaptive softmax (including standard softmax)
|
|
||||||
else:
|
self.crit = TFAdaptiveSoftmaxMask(
|
||||||
self.crit = TFAdaptiveSoftmaxMask(
|
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
|
||||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
|
)
|
||||||
)
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
""" Double-check if you are using adaptive softmax.
|
||||||
|
"""
|
||||||
|
if len(self.crit.out_layers) > 0:
|
||||||
|
return self.crit.out_layers[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||||
@@ -820,13 +845,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
last_hidden = transformer_outputs[0]
|
last_hidden = transformer_outputs[0]
|
||||||
pred_hid = last_hidden[:, -tgt_len:]
|
pred_hid = last_hidden[:, -tgt_len:]
|
||||||
outputs = transformer_outputs[1:]
|
outputs = transformer_outputs[1:]
|
||||||
if self.sample_softmax > 0 and training:
|
|
||||||
raise NotImplementedError
|
softmax_output = self.crit([pred_hid, labels], training=training)
|
||||||
else:
|
outputs = [softmax_output] + outputs
|
||||||
# pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
|
|
||||||
softmax_output = self.crit([pred_hid, labels], training=training)
|
|
||||||
# softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
|
|
||||||
outputs = [softmax_output] + outputs
|
|
||||||
|
|
||||||
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
@@ -809,42 +809,37 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.transformer = TransfoXLModel(config)
|
self.transformer = TransfoXLModel(config)
|
||||||
self.sample_softmax = config.sample_softmax
|
self.sample_softmax = config.sample_softmax
|
||||||
# use sampled softmax
|
|
||||||
if config.sample_softmax > 0:
|
assert (
|
||||||
self.out_layer = nn.Linear(config.d_model, config.vocab_size)
|
self.sample_softmax <= 0
|
||||||
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
|
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
|
||||||
# use adaptive softmax (including standard softmax)
|
|
||||||
else:
|
self.crit = ProjectedAdaptiveLogSoftmax(
|
||||||
self.crit = ProjectedAdaptiveLogSoftmax(
|
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
|
||||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
|
)
|
||||||
)
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
"""
|
"""
|
||||||
Run this to be sure output and input (adaptive) softmax weights are tied
|
Run this to be sure output and input (adaptive) softmax weights are tied
|
||||||
"""
|
"""
|
||||||
# sampled softmax
|
|
||||||
if self.sample_softmax > 0:
|
if self.config.tie_weight:
|
||||||
if self.config.tie_weight:
|
for i in range(len(self.crit.out_layers)):
|
||||||
self.out_layer.weight = self.transformer.word_emb.weight
|
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
||||||
# adaptive softmax (including standard softmax)
|
if self.config.tie_projs:
|
||||||
else:
|
for i, tie_proj in enumerate(self.config.tie_projs):
|
||||||
if self.config.tie_weight:
|
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
|
||||||
for i in range(len(self.crit.out_layers)):
|
if self.config.torchscript:
|
||||||
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
|
||||||
if self.config.tie_projs:
|
else:
|
||||||
for i, tie_proj in enumerate(self.config.tie_projs):
|
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
|
||||||
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
|
elif tie_proj and self.config.div_val != 1:
|
||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
|
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
|
||||||
else:
|
else:
|
||||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
|
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
||||||
elif tie_proj and self.config.div_val != 1:
|
|
||||||
if self.config.torchscript:
|
|
||||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
|
|
||||||
else:
|
|
||||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||||
@@ -908,22 +903,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
last_hidden = transformer_outputs[0]
|
last_hidden = transformer_outputs[0]
|
||||||
pred_hid = last_hidden[:, -tgt_len:]
|
pred_hid = last_hidden[:, -tgt_len:]
|
||||||
outputs = transformer_outputs[1:]
|
outputs = transformer_outputs[1:]
|
||||||
if self.sample_softmax > 0 and self.training:
|
|
||||||
assert self.config.tie_weight
|
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
|
||||||
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
|
if labels is None:
|
||||||
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
|
softmax_output = softmax_output.view(bsz, tgt_len, -1)
|
||||||
outputs = [softmax_output] + outputs
|
outputs = [softmax_output] + outputs
|
||||||
if labels is not None:
|
|
||||||
# TODO: This is not implemented
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
else:
|
||||||
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
|
softmax_output = softmax_output.view(bsz, tgt_len)
|
||||||
if labels is None:
|
outputs = [softmax_output, None] + outputs
|
||||||
softmax_output = softmax_output.view(bsz, tgt_len, -1)
|
|
||||||
outputs = [softmax_output] + outputs
|
|
||||||
else:
|
|
||||||
softmax_output = softmax_output.view(bsz, tgt_len)
|
|
||||||
outputs = [softmax_output, None] + outputs
|
|
||||||
|
|
||||||
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -241,77 +241,3 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
|
|||||||
out[:, start_idx, stop_idx] = logprob_i
|
out[:, start_idx, stop_idx] = logprob_i
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class LogUniformSampler(object):
|
|
||||||
def __init__(self, range_max, n_sample):
|
|
||||||
"""
|
|
||||||
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
|
|
||||||
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
|
|
||||||
|
|
||||||
expected count can be approximated by 1 - (1 - p)^n
|
|
||||||
and we use a numerically stable version -expm1(num_tries * log1p(-p))
|
|
||||||
|
|
||||||
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
self.range_max = range_max
|
|
||||||
log_indices = torch.arange(1.0, range_max + 2.0, 1.0).log_()
|
|
||||||
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
|
|
||||||
|
|
||||||
self.log_q = (-(-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
|
|
||||||
|
|
||||||
self.n_sample = n_sample
|
|
||||||
|
|
||||||
def sample(self, labels):
|
|
||||||
"""
|
|
||||||
labels: [b1, b2]
|
|
||||||
Return
|
|
||||||
true_log_probs: [b1, b2]
|
|
||||||
samp_log_probs: [n_sample]
|
|
||||||
neg_samples: [n_sample]
|
|
||||||
"""
|
|
||||||
|
|
||||||
# neg_samples = torch.empty(0).long()
|
|
||||||
n_sample = self.n_sample
|
|
||||||
n_tries = 2 * n_sample
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
|
|
||||||
device = labels.device
|
|
||||||
neg_samples = neg_samples.to(device)
|
|
||||||
true_log_probs = self.log_q[labels].to(device)
|
|
||||||
samp_log_probs = self.log_q[neg_samples].to(device)
|
|
||||||
return true_log_probs, samp_log_probs, neg_samples
|
|
||||||
|
|
||||||
|
|
||||||
def sample_logits(embedding, bias, labels, inputs, sampler):
|
|
||||||
"""
|
|
||||||
embedding: an nn.Embedding layer
|
|
||||||
bias: [n_vocab]
|
|
||||||
labels: [b1, b2]
|
|
||||||
inputs: [b1, b2, n_emb]
|
|
||||||
sampler: you may use a LogUniformSampler
|
|
||||||
Return
|
|
||||||
logits: [b1, b2, 1 + n_sample]
|
|
||||||
"""
|
|
||||||
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
|
|
||||||
n_sample = neg_samples.size(0)
|
|
||||||
b1, b2 = labels.size(0), labels.size(1)
|
|
||||||
all_ids = torch.cat([labels.view(-1), neg_samples])
|
|
||||||
all_w = embedding(all_ids)
|
|
||||||
true_w = all_w[:-n_sample].view(b1, b2, -1)
|
|
||||||
sample_w = all_w[-n_sample:].view(n_sample, -1)
|
|
||||||
|
|
||||||
all_b = bias[all_ids]
|
|
||||||
true_b = all_b[:-n_sample].view(b1, b2)
|
|
||||||
sample_b = all_b[-n_sample:]
|
|
||||||
|
|
||||||
hit = (labels[:, :, None] == neg_samples).detach()
|
|
||||||
|
|
||||||
true_logits = torch.einsum("ijk,ijk->ij", [true_w, inputs]) + true_b - true_log_probs
|
|
||||||
sample_logits = torch.einsum("lk,ijk->ijl", [sample_w, inputs]) + sample_b - samp_log_probs
|
|
||||||
sample_logits.masked_fill_(hit, -1e30)
|
|
||||||
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import tf_top_k_top_p_filtering
|
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
@@ -348,7 +348,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
|
||||||
x = model.get_output_embeddings()
|
x = model.get_output_embeddings()
|
||||||
assert x is None or isinstance(x, tf.keras.layers.Layer)
|
assert x is None or isinstance(x, tf.keras.layers.Layer)
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
|
|||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from transformers.modeling_tf_transfo_xl import (
|
from transformers import (
|
||||||
TFTransfoXLModel,
|
TFTransfoXLModel,
|
||||||
TFTransfoXLLMHeadModel,
|
TFTransfoXLLMHeadModel,
|
||||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
0,
|
0,
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
dtype=tf.int31,
|
dtype=tf.int32,
|
||||||
)
|
)
|
||||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||||
# ( except for Alexei and Maria ) are discovered .
|
# ( except for Alexei and Maria ) are discovered .
|
||||||
@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||||
|
|
||||||
# TODO: add this test when trasnfo-xl-lmhead is implemented
|
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||||
with self.assertRaises(NotImplementedError):
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
model.generate(input_ids, max_length=200, do_sample=False)
|
|
||||||
print(expected_output_ids)
|
|
||||||
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
|
|
||||||
|
|||||||
@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def check_transfo_xl_model_output(self, result):
|
def check_transfo_xl_model_output(self, result):
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def check_transfo_xl_lm_head_output(self, result):
|
def check_transfo_xl_lm_head_output(self, result):
|
||||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_2"]),
|
list(list(mem.size()) for mem in result["mems_2"]),
|
||||||
|
|||||||
Reference in New Issue
Block a user