Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)
* first commit * work in progress * make language generation task pass * update to working version for LM * delete print * remove dead code * make style
This commit is contained in:
committed by
GitHub
parent
efdb46b6e2
commit
292186a3e7
@@ -357,6 +357,7 @@ if is_tf_available():
|
||||
TFTransfoXLModel,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TFAdaptiveEmbedding,
|
||||
)
|
||||
|
||||
from .modeling_tf_xlnet import (
|
||||
|
||||
@@ -733,6 +733,25 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
class TFTransfoXLLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Transformer-XL Model with a language modeling head on top
|
||||
(adaptive softmax with weights tied to the adaptive input embeddings)""",
|
||||
@@ -743,14 +762,20 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
|
||||
self.sample_softmax = config.sample_softmax
|
||||
# use sampled softmax
|
||||
if config.sample_softmax > 0:
|
||||
raise NotImplementedError
|
||||
# use adaptive softmax (including standard softmax)
|
||||
else:
|
||||
self.crit = TFAdaptiveSoftmaxMask(
|
||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
|
||||
)
|
||||
assert (
|
||||
self.sample_softmax <= 0
|
||||
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
|
||||
|
||||
self.crit = TFAdaptiveSoftmaxMask(
|
||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
|
||||
)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Double-check if you are using adaptive softmax.
|
||||
"""
|
||||
if len(self.crit.out_layers) > 0:
|
||||
return self.crit.out_layers[-1]
|
||||
return None
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||
@@ -820,13 +845,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
outputs = transformer_outputs[1:]
|
||||
if self.sample_softmax > 0 and training:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
|
||||
softmax_output = self.crit([pred_hid, labels], training=training)
|
||||
# softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
|
||||
outputs = [softmax_output] + outputs
|
||||
|
||||
softmax_output = self.crit([pred_hid, labels], training=training)
|
||||
outputs = [softmax_output] + outputs
|
||||
|
||||
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch.nn.functional as F
|
||||
|
||||
from .configuration_transfo_xl import TransfoXLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
@@ -809,42 +809,37 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.transformer = TransfoXLModel(config)
|
||||
self.sample_softmax = config.sample_softmax
|
||||
# use sampled softmax
|
||||
if config.sample_softmax > 0:
|
||||
self.out_layer = nn.Linear(config.d_model, config.vocab_size)
|
||||
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
|
||||
# use adaptive softmax (including standard softmax)
|
||||
else:
|
||||
self.crit = ProjectedAdaptiveLogSoftmax(
|
||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
|
||||
)
|
||||
|
||||
assert (
|
||||
self.sample_softmax <= 0
|
||||
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
|
||||
|
||||
self.crit = ProjectedAdaptiveLogSoftmax(
|
||||
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Run this to be sure output and input (adaptive) softmax weights are tied
|
||||
"""
|
||||
# sampled softmax
|
||||
if self.sample_softmax > 0:
|
||||
if self.config.tie_weight:
|
||||
self.out_layer.weight = self.transformer.word_emb.weight
|
||||
# adaptive softmax (including standard softmax)
|
||||
else:
|
||||
if self.config.tie_weight:
|
||||
for i in range(len(self.crit.out_layers)):
|
||||
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
||||
if self.config.tie_projs:
|
||||
for i, tie_proj in enumerate(self.config.tie_projs):
|
||||
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
|
||||
if self.config.torchscript:
|
||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
|
||||
else:
|
||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
|
||||
elif tie_proj and self.config.div_val != 1:
|
||||
if self.config.torchscript:
|
||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
|
||||
else:
|
||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
||||
|
||||
if self.config.tie_weight:
|
||||
for i in range(len(self.crit.out_layers)):
|
||||
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
||||
if self.config.tie_projs:
|
||||
for i, tie_proj in enumerate(self.config.tie_projs):
|
||||
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
|
||||
if self.config.torchscript:
|
||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
|
||||
else:
|
||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
|
||||
elif tie_proj and self.config.div_val != 1:
|
||||
if self.config.torchscript:
|
||||
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
|
||||
else:
|
||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||
@@ -908,22 +903,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
outputs = transformer_outputs[1:]
|
||||
if self.sample_softmax > 0 and self.training:
|
||||
assert self.config.tie_weight
|
||||
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
|
||||
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
|
||||
|
||||
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
|
||||
if labels is None:
|
||||
softmax_output = softmax_output.view(bsz, tgt_len, -1)
|
||||
outputs = [softmax_output] + outputs
|
||||
if labels is not None:
|
||||
# TODO: This is not implemented
|
||||
raise NotImplementedError
|
||||
else:
|
||||
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
|
||||
if labels is None:
|
||||
softmax_output = softmax_output.view(bsz, tgt_len, -1)
|
||||
outputs = [softmax_output] + outputs
|
||||
else:
|
||||
softmax_output = softmax_output.view(bsz, tgt_len)
|
||||
outputs = [softmax_output, None] + outputs
|
||||
softmax_output = softmax_output.view(bsz, tgt_len)
|
||||
outputs = [softmax_output, None] + outputs
|
||||
|
||||
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
@@ -241,77 +241,3 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
|
||||
out[:, start_idx, stop_idx] = logprob_i
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class LogUniformSampler(object):
|
||||
def __init__(self, range_max, n_sample):
|
||||
"""
|
||||
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
|
||||
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
|
||||
|
||||
expected count can be approximated by 1 - (1 - p)^n
|
||||
and we use a numerically stable version -expm1(num_tries * log1p(-p))
|
||||
|
||||
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
|
||||
"""
|
||||
with torch.no_grad():
|
||||
self.range_max = range_max
|
||||
log_indices = torch.arange(1.0, range_max + 2.0, 1.0).log_()
|
||||
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
|
||||
|
||||
self.log_q = (-(-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
|
||||
|
||||
self.n_sample = n_sample
|
||||
|
||||
def sample(self, labels):
|
||||
"""
|
||||
labels: [b1, b2]
|
||||
Return
|
||||
true_log_probs: [b1, b2]
|
||||
samp_log_probs: [n_sample]
|
||||
neg_samples: [n_sample]
|
||||
"""
|
||||
|
||||
# neg_samples = torch.empty(0).long()
|
||||
n_sample = self.n_sample
|
||||
n_tries = 2 * n_sample
|
||||
|
||||
with torch.no_grad():
|
||||
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
|
||||
device = labels.device
|
||||
neg_samples = neg_samples.to(device)
|
||||
true_log_probs = self.log_q[labels].to(device)
|
||||
samp_log_probs = self.log_q[neg_samples].to(device)
|
||||
return true_log_probs, samp_log_probs, neg_samples
|
||||
|
||||
|
||||
def sample_logits(embedding, bias, labels, inputs, sampler):
|
||||
"""
|
||||
embedding: an nn.Embedding layer
|
||||
bias: [n_vocab]
|
||||
labels: [b1, b2]
|
||||
inputs: [b1, b2, n_emb]
|
||||
sampler: you may use a LogUniformSampler
|
||||
Return
|
||||
logits: [b1, b2, 1 + n_sample]
|
||||
"""
|
||||
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
|
||||
n_sample = neg_samples.size(0)
|
||||
b1, b2 = labels.size(0), labels.size(1)
|
||||
all_ids = torch.cat([labels.view(-1), neg_samples])
|
||||
all_w = embedding(all_ids)
|
||||
true_w = all_w[:-n_sample].view(b1, b2, -1)
|
||||
sample_w = all_w[-n_sample:].view(n_sample, -1)
|
||||
|
||||
all_b = bias[all_ids]
|
||||
true_b = all_b[:-n_sample].view(b1, b2)
|
||||
sample_b = all_b[-n_sample:]
|
||||
|
||||
hit = (labels[:, :, None] == neg_samples).detach()
|
||||
|
||||
true_logits = torch.einsum("ijk,ijk->ij", [true_w, inputs]) + true_b - true_log_probs
|
||||
sample_logits = torch.einsum("lk,ijk->ijl", [sample_w, inputs]) + sample_b - samp_log_probs
|
||||
sample_logits.masked_fill_(hit, -1e30)
|
||||
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
|
||||
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user