Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -136,7 +136,7 @@ if is_sklearn_available():
|
|||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering
|
||||||
from .modeling_auto import (
|
from .modeling_auto import (
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForPreTraining,
|
AutoModelForPreTraining,
|
||||||
@@ -291,7 +291,13 @@ if is_torch_available():
|
|||||||
|
|
||||||
# TensorFlow
|
# TensorFlow
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFSharedEmbeddings,
|
||||||
|
TFSequenceSummary,
|
||||||
|
shape_list,
|
||||||
|
tf_top_k_top_p_filtering,
|
||||||
|
)
|
||||||
from .modeling_tf_auto import (
|
from .modeling_tf_auto import (
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
|
|||||||
@@ -454,14 +454,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if "past" in kwargs and kwargs["past"]:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
return {"input_ids": input_ids, "past": past}
|
||||||
inputs.update(kwargs)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -525,14 +525,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if "past" in kwargs and kwargs["past"]:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
return {"input_ids": input_ids, "past": past}
|
||||||
inputs.update(kwargs)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -105,8 +105,8 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
v = self.split_into_heads(v, batch_size)
|
v = self.split_into_heads(v, batch_size)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = tf.unstack(layer_past, axis=1)
|
past_key, past_value = tf.unstack(layer_past, axis=1)
|
||||||
k = tf.concat((past_key, k), dim=-2)
|
k = tf.concat((past_key, k), axis=-2)
|
||||||
v = tf.concat((past_value, v), dim=-2)
|
v = tf.concat((past_value, v), axis=-2)
|
||||||
present = tf.stack((k, v), axis=1)
|
present = tf.stack((k, v), axis=1)
|
||||||
|
|
||||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||||
@@ -505,6 +505,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.input_embeddings
|
return self.lm_head.input_embeddings
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
    """Build the keyword arguments for one decoding step.

    Args:
        inputs: token ids produced so far. It is indexed as ``inputs[:, -1]``,
            so it has to be (at least) two-dimensional.
        past: cached state returned by the previous step, or a falsy value
            (e.g. ``None``) on the first step.
        **kwargs: accepted for signature compatibility; not used here.

    Returns:
        dict with the keys ``"inputs"`` and ``"past"``.
    """
    # once a cache is available only the newest token has to be fed to the model
    if past:
        inputs = tf.expand_dims(inputs[:, -1], -1)

    return {"inputs": inputs, "past": past}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -500,6 +500,13 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.transformer.wte
|
return self.transformer.wte
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
    """Build the keyword arguments for one decoding step.

    Args:
        inputs: token ids produced so far. It is indexed as ``inputs[:, -1]``,
            so it has to be (at least) two-dimensional.
        past: cached state returned by the previous step, or a falsy value
            (e.g. ``None``) on the first step.
        **kwargs: accepted for signature compatibility; not used here.

    Returns:
        dict with the keys ``"inputs"`` and ``"past"``.
    """
    # once a cache is available only the newest token has to be fed to the model
    if past:
        inputs = tf.expand_dims(inputs[:, -1], -1)

    return {"inputs": inputs, "past": past}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -826,3 +826,12 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
outputs = [softmax_output] + outputs
|
outputs = [softmax_output] + outputs
|
||||||
|
|
||||||
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
    """Build the keyword arguments for one decoding step.

    Args:
        inputs: token ids produced so far; passed through unchanged.
        past: memory returned by the previous step, or a falsy value
            (e.g. ``None``) on the first step.
        **model_kwargs: accepted for signature compatibility; not used here.

    Returns:
        dict with the key ``"inputs"`` and, when ``past`` is truthy, ``"mems"``.
    """
    inputs = {"inputs": inputs}

    # hand the cached memory back to the model for faster decoding
    if past:
        inputs["mems"] = past

    return inputs
|
||||||
|
|||||||
@@ -384,6 +384,432 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, **kwargs):
    """Default hook used by generation: wrap ``inputs`` for the model call.

    Extra keyword arguments (such as ``past``) are accepted and ignored.

    Returns:
        dict with the single key ``"inputs"``.
    """
    return {"inputs": inputs}
|
||||||
|
|
||||||
|
def _do_output_past(self, outputs):
    """Tell whether ``outputs[1]`` should be reused as cached state.

    Returns ``True`` only when the model returned more than one output and
    either ``config.mem_len`` is a positive number, or ``config.mem_len`` is
    missing/falsy and ``config.output_past`` is truthy.
    """
    # a single output cannot carry a cache
    if len(outputs) <= 1:
        return False

    mem_len = getattr(self.config, "mem_len", None)
    if mem_len:
        # a truthy mem_len takes precedence over output_past
        return mem_len > 0

    return bool(getattr(self.config, "output_past", False))
|
||||||
|
|
||||||
|
def generate(
    self,
    input_ids=None,
    max_length=None,
    do_sample=True,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_ids=None,
    length_penalty=None,
    num_return_sequences=None,
):
    r""" Generates sequences for models with a LM head. The method supports greedy or penalized greedy decoding
    and sampling with top-k or nucleus sampling. ``num_beams > 1`` is dispatched to ``_generate_beam_search``.

    Adapted in part from `Facebook's XLM beam search code`_.

    .. _`Facebook's XLM beam search code`:
       https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529

    Every argument left as `None` falls back to the attribute of the same name on `self.config`.

    Parameters:

        input_ids: (`optional`) `tf.Tensor` of shape `(batch_size, sequence_length)`
            The sequence used as a prompt for the generation. If `None` the method initializes
            it as a tensor of shape `(1, 1)` filled with `bos_token_id`.

        max_length: (`optional`) int
            The max length of the sequence to be generated. Between 1 and infinity. Default to 20.

        do_sample: (`optional`) bool
            If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.

        num_beams: (`optional`) int
            Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

        temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.

        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.

        top_p: (`optional`) float
            The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

        repetition_penalty: (`optional`) float
            The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

        bos_token_id: (`optional`) int
            Beginning of sentence token if no prompt is provided. Default to 0.

        pad_token_id: (`optional`) int
            Token used to pad finished sequences. If `None` while `eos_token_ids` is set, the first
            `eos_token_ids` entry is used (a warning is logged).

        eos_token_ids: (`optional`) int or list of int
            End of sequence token or list of tokens to stop the generation. Default to 0.

        length_penalty: (`optional`) float
            Exponential penalty to the length. Default to 1.

        num_return_sequences: (`optional`) int
            The number of independently computed returned sequences for each element in the batch. Default to 1.

    Return:

        output: `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`
            sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

    Examples::

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        input_context = 'The dog'
        input_ids = tf.constant([tokenizer.encode(input_context)])  # encode input context, shape (1, sequence_length)
        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # generate 3 sequences by sampling
        for i in range(3):  # 3 output sequences were generated
            print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
        model = TFAutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
        input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
        input_ids = tf.constant([tokenizer.encode(input_context)])  # encode input context, shape (1, sequence_length)
        outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

    """

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError(
            "You tried to generate sequences with a model that does not have a LM Head."
            "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
        )

    # resolve every unset argument from the model configuration
    max_length = max_length if max_length is not None else self.config.max_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    temperature = temperature if temperature is not None else self.config.temperature
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )

    if input_ids is not None:
        batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
    else:
        batch_size = 1
    if isinstance(eos_token_ids, int):
        eos_token_ids = [eos_token_ids]

    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
    assert temperature > 0, "`temperature` should be strictely positive."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert input_ids is not None or (
        isinstance(bos_token_id, int) and bos_token_id >= 0
    ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
    assert pad_token_id is None or (
        isinstance(pad_token_id, int) and (pad_token_id >= 0)
    ), "`pad_token_id` should be a positive integer."
    # `all(...)` is required here: a bare generator expression is always truthy,
    # which would let invalid entries through unchecked
    assert (eos_token_ids is None) or (
        isinstance(eos_token_ids, (list, tuple)) and all((isinstance(e, int) and e >= 0) for e in eos_token_ids)
    ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
    assert length_penalty > 0, "`length_penalty` should be strictely positive."
    assert (
        isinstance(num_return_sequences, int) and num_return_sequences > 0
    ), "`num_return_sequences` should be a strictely positive integer."

    if input_ids is None:
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
        input_ids = tf.fill((batch_size, 1), bos_token_id)
    else:
        assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

    if pad_token_id is None and eos_token_ids is not None:
        logger.warning(
            "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
        )
        pad_token_id = eos_token_ids[0]

    # current position and vocab size
    cur_len = shape_list(input_ids)[1]
    vocab_size = self.config.vocab_size

    if num_return_sequences != 1:
        # Expand input to num return sequences: each prompt row is repeated
        # `num_return_sequences` times, keeping copies of a prompt adjacent
        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
        effective_batch_size = batch_size * num_return_sequences
        input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
    else:
        effective_batch_size = batch_size

    if num_beams > 1:
        output = self._generate_beam_search(
            input_ids,
            cur_len,
            max_length,
            do_sample,
            temperature,
            top_k,
            top_p,
            repetition_penalty,
            pad_token_id,
            eos_token_ids,
            effective_batch_size,
            length_penalty,
            num_beams,
            vocab_size,
        )
    else:
        output = self._generate_no_beam_search(
            input_ids,
            cur_len,
            max_length,
            do_sample,
            temperature,
            top_k,
            top_p,
            repetition_penalty,
            pad_token_id,
            eos_token_ids,
            effective_batch_size,
        )

    return output
|
||||||
|
|
||||||
|
def _generate_no_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    pad_token_id,
    eos_token_ids,
    batch_size,
):
    """ Generate sequences for each example without beam search (num_beams == 1).
        All returned sequences are generated independently.

        One token per row is appended to `input_ids` on every step until `max_length`
        is reached or every row has produced one of `eos_token_ids`. Runs eagerly:
        `.numpy()` is called on intermediate tensors.
    """

    def _create_next_token_logits_penalties(input_ids, logits):
        # create logit penalties for already seen input_ids
        token_penalties = np.ones(shape_list(logits))
        prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
        for i, prev_input_id in enumerate(prev_input_ids):
            logit_penalized = logits[i].numpy()[prev_input_id]
            # The factors go into a separate array. Writing them into
            # `logit_penalized` itself would make the `> 0` mask below also match
            # the entries just set to `repetition_penalty` and overwrite them.
            logit_penalties = np.ones(logit_penalized.shape)
            # if previous logit score is < 0 then multiply repetition penalty else divide
            logit_penalties[logit_penalized < 0] = repetition_penalty
            logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
            np.put(token_penalties[i], prev_input_id, logit_penalties)
        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

    # current position / max lengths / length of generated sentences / unfinished sentences
    unfinished_sents = tf.ones_like(input_ids[:, 0])
    sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

    past = None

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
        outputs = self(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]

        # if model has past, then set the past variable to speed up decoding
        if self._do_output_past(outputs):
            past = outputs[1]

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
            next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            # Top-p/top-k filtering
            next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            # Sample
            next_token = tf.squeeze(
                tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
            )
        else:
            # Greedy decoding
            next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

        # update generations and finished sentences
        if eos_token_ids is not None:
            # pad finished sentences if eos_token_ids exist
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)

        if eos_token_ids is not None:
            for eos_token_id in eos_token_ids:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                )
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )

                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

        cur_len = cur_len + 1

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if tf.math.reduce_max(unfinished_sents) == 0:
            break

    # if there are different sentences lengths in the batch, some batches have to be padded
    min_sent_length = tf.math.reduce_min(sent_lengths)
    max_sent_length = tf.math.reduce_max(sent_lengths)
    if min_sent_length != max_sent_length:
        assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
        longest = int(max_sent_length.numpy())

        # finished sents are filled with pad_token
        padding = tf.ones([batch_size, longest], dtype=tf.int32) * pad_token_id

        # create length masks for tf.where operation; every operand is built
        # `longest` columns wide so that condition, input and padding agree in shape
        broad_casted_sent_lengths = tf.broadcast_to(tf.expand_dims(sent_lengths, -1), [batch_size, longest])
        broad_casted_range = tf.transpose(
            tf.broadcast_to(
                tf.expand_dims(tf.range(longest, dtype=sent_lengths.dtype), -1), [longest, batch_size]
            )
        )

        decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids[:, :longest], padding)
    else:
        decoded = input_ids

    return decoded
|
||||||
|
|
||||||
|
def _generate_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    pad_token_id,
    eos_token_ids,
    batch_size,
    length_penalty,
    num_beams,
    vocab_size,
):
    """ Generate sequences for each example with beam search (num_beams > 1).

        Not implemented for TensorFlow models yet. An explicit error is raised
        rather than returning `None`, so a caller does not mistake the missing
        result for generated sequences.

        Raises:
            NotImplementedError: always.
    """
    raise NotImplementedError(
        "Beam search (`num_beams` > 1) is not implemented for TensorFlow models yet. Use `num_beams=1`."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written into every filtered-out position.
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        Returns:
            The logits with every filtered-out position replaced by `filter_value`.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    logits_shape = shape_list(logits)

    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)

    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
        sorted_logits = tf.gather(
            logits, sorted_indices, axis=-1, batch_dims=1
        )  # expects logits to be of dim (batch_size, vocab_size)

        cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove = tf.concat(
                [
                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
                ],
                -1,
            )

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
        # tf.roll wraps the last column around to the front, so clear the first column again
        sorted_indices_to_remove = tf.concat(
            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
        )
        # scatter sorted tensors to original indexing
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
    return logits
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_values_on_batch_indices(values, batch_indices):
    """Scatter `values` to the per-row positions given by `batch_indices`.

    For every row `b` and column `j`, `values[b, j]` is written to
    `result[b, batch_indices[b, j]]`. The result has the shape of `batch_indices`.
    """
    shape = shape_list(batch_indices)
    # broadcast batch dim to shape
    broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
    # transform batch_indices to pair_indices, i.e. one (row, column) pair per value
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
    # scatter values to pair indices
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tensor_by_indices_to_value(tensor, indices, value):
    """Return a copy of `tensor` with `value` wherever `indices` is true.

    `indices` is used as the condition of `tf.where`, i.e. as a boolean mask.
    """
    # create value_tensor since tensor value assignment is not possible in TF
    value_tensor = tf.zeros_like(tensor) + value
    return tf.where(indices, value_tensor, tensor)
|
||||||
|
|
||||||
|
|
||||||
class TFConv1D(tf.keras.layers.Layer):
|
class TFConv1D(tf.keras.layers.Layer):
|
||||||
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
|
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
|
||||||
|
|||||||
@@ -657,6 +657,20 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.pred_layer.input_embeddings
|
return self.pred_layer.input_embeddings
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, **kwargs):
    """Build the keyword arguments for one decoding step.

    One column filled with `config.mask_token_id` is appended to `inputs`.
    When `config.lang_id` is set, a language-id tensor of the same shape as
    the extended inputs is built as well.

    Returns:
        dict with the keys ``"inputs"`` and ``"langs"`` (``None`` when
        `config.lang_id` is ``None``).
    """
    mask_token_id = self.config.mask_token_id
    lang_id = self.config.lang_id

    # append one mask token per row
    effective_batch_size = inputs.shape[0]
    mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
    inputs = tf.concat([inputs, mask_token], axis=1)

    if lang_id is not None:
        langs = tf.ones_like(inputs) * lang_id
    else:
        langs = None
    return {"inputs": inputs, "langs": langs}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -837,6 +837,32 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_loss.input_embeddings
|
return self.lm_loss.input_embeddings
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
    """Build the keyword arguments for one decoding step.

    A dummy token is appended to `inputs`, together with a permutation mask
    that hides this last position and a target mapping that selects it.

    Returns:
        dict with the keys ``"inputs"``, ``"perm_mask"``, ``"target_mapping"``
        and, when `past` is truthy, ``"mems"``.
    """
    # Add dummy token at the end (no attention on this one)

    effective_batch_size = inputs.shape[0]
    dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
    inputs = tf.concat([inputs, dummy_token], axis=1)

    # Build permutation mask so that previous tokens don't see last token:
    # zeros everywhere except the last column, which is all ones
    sequence_length = inputs.shape[1]
    perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
    perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
    perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

    # We'll only predict the last token
    target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
    target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
    target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

    inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping}

    # hand the cached memory back to the model for faster decoding
    if past:
        inputs["mems"] = past

    return inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -935,11 +935,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
return self.crit.out_layers[-1]
|
return self.crit.out_layers[-1]
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
|
||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
# if past is defined in model kwargs then use it for faster decoding
|
# if past is defined in model kwargs then use it for faster decoding
|
||||||
if "past" in model_kwargs and model_kwargs["past"]:
|
if past:
|
||||||
inputs["mems"] = model_kwargs["past"]
|
inputs["mems"] = past
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -935,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_loss
|
return self.lm_loss
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
|
||||||
# Add dummy token at the end (no attention on this one)
|
# Add dummy token at the end (no attention on this one)
|
||||||
|
|
||||||
effective_batch_size = input_ids.shape[0]
|
effective_batch_size = input_ids.shape[0]
|
||||||
@@ -958,8 +958,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||||
|
|
||||||
# if past is defined in model kwargs then use it for faster decoding
|
# if past is defined in model kwargs then use it for faster decoding
|
||||||
if "past" in model_kwargs and model_kwargs["past"]:
|
if past:
|
||||||
inputs["mems"] = model_kwargs["past"]
|
inputs["mems"] = past
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -263,7 +264,7 @@ class ModelTesterMixin:
|
|||||||
# Prepare head_mask
|
# Prepare head_mask
|
||||||
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
|
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
|
||||||
head_mask = torch.ones(
|
head_mask = torch.ones(
|
||||||
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device
|
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
|
||||||
)
|
)
|
||||||
head_mask[0, 0] = 0
|
head_mask[0, 0] = 0
|
||||||
head_mask[-1, :-1] = 0
|
head_mask[-1, :-1] = 0
|
||||||
@@ -303,7 +304,7 @@ class ModelTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
if "head_mask" in inputs_dict:
|
if "head_mask" in inputs_dict:
|
||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
@@ -313,7 +314,10 @@ class ModelTesterMixin:
|
|||||||
model = model_class(config=config)
|
model = model_class(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
heads_to_prune = {
|
||||||
|
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||||
|
-1: [0],
|
||||||
|
}
|
||||||
model.prune_heads(heads_to_prune)
|
model.prune_heads(heads_to_prune)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**inputs_dict)
|
||||||
@@ -329,7 +333,7 @@ class ModelTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
if "head_mask" in inputs_dict:
|
if "head_mask" in inputs_dict:
|
||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
@@ -339,7 +343,10 @@ class ModelTesterMixin:
|
|||||||
model = model_class(config=config)
|
model = model_class(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
heads_to_prune = {
|
||||||
|
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||||
|
-1: [0],
|
||||||
|
}
|
||||||
model.prune_heads(heads_to_prune)
|
model.prune_heads(heads_to_prune)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||||
@@ -359,7 +366,7 @@ class ModelTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
if "head_mask" in inputs_dict:
|
if "head_mask" in inputs_dict:
|
||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
@@ -367,7 +374,10 @@ class ModelTesterMixin:
|
|||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
config.output_hidden_states = False
|
config.output_hidden_states = False
|
||||||
|
|
||||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
heads_to_prune = {
|
||||||
|
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||||
|
-1: [0],
|
||||||
|
}
|
||||||
config.pruned_heads = heads_to_prune
|
config.pruned_heads = heads_to_prune
|
||||||
|
|
||||||
model = model_class(config=config)
|
model = model_class(config=config)
|
||||||
@@ -387,7 +397,7 @@ class ModelTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
if "head_mask" in inputs_dict:
|
if "head_mask" in inputs_dict:
|
||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
@@ -465,7 +475,7 @@ class ModelTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.test_resize_embeddings:
|
if not self.test_resize_embeddings:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -634,6 +644,7 @@ class ModelTesterMixin:
|
|||||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||||
# batch_size > 1, greedy
|
# batch_size > 1, greedy
|
||||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||||
|
|
||||||
# batch_size > 1, num_beams > 1, sample
|
# batch_size > 1, num_beams > 1, sample
|
||||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||||
# batch_size > 1, num_beams > 1, greedy
|
# batch_size > 1, num_beams > 1, greedy
|
||||||
@@ -704,3 +715,110 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
self.assertEqual(model.config.output_attentions, True)
|
self.assertEqual(model.config.output_attentions, True)
|
||||||
self.assertEqual(model.config.output_hidden_states, True)
|
self.assertEqual(model.config.output_hidden_states, True)
|
||||||
self.assertEqual(model.config, config)
|
self.assertEqual(model.config, config)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
# tests whether the top_k_top_p function behaves as expected
|
||||||
|
def test_top_k_top_p_filtering(self):
|
||||||
|
logits = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
8.2220991, # 3rd highest value; idx. 0
|
||||||
|
-0.5620044,
|
||||||
|
5.23229752,
|
||||||
|
4.0386393,
|
||||||
|
-6.8798378,
|
||||||
|
-0.54785802,
|
||||||
|
-3.2012153,
|
||||||
|
2.92777176,
|
||||||
|
1.88171953,
|
||||||
|
7.35341276, # 5th highest value; idx. 9
|
||||||
|
8.43207833, # 2nd highest value; idx. 10
|
||||||
|
-9.85711836,
|
||||||
|
-5.96209236,
|
||||||
|
-1.13039161,
|
||||||
|
-7.1115294,
|
||||||
|
-0.8369633,
|
||||||
|
-5.3186408,
|
||||||
|
7.06427407,
|
||||||
|
0.81369344,
|
||||||
|
-0.82023817,
|
||||||
|
-5.9179796,
|
||||||
|
0.58813443,
|
||||||
|
-6.99778438,
|
||||||
|
4.71551189,
|
||||||
|
-0.18771637,
|
||||||
|
7.44020759, # 4th highest value; idx. 25
|
||||||
|
9.38450987, # 1st highest value; idx. 26
|
||||||
|
2.12662941,
|
||||||
|
-9.32562038,
|
||||||
|
2.35652522,
|
||||||
|
], # cumulative prob of 5 highest values <= 0.6
|
||||||
|
[
|
||||||
|
0.58425518,
|
||||||
|
4.53139238,
|
||||||
|
-5.57510464,
|
||||||
|
-6.28030699,
|
||||||
|
-7.19529503,
|
||||||
|
-4.02122551,
|
||||||
|
1.39337037,
|
||||||
|
-6.06707057,
|
||||||
|
1.59480517,
|
||||||
|
-9.643119,
|
||||||
|
0.03907799,
|
||||||
|
0.67231762,
|
||||||
|
-8.88206726,
|
||||||
|
6.27115922, # 4th highest value; idx. 13
|
||||||
|
2.28520723,
|
||||||
|
4.82767506,
|
||||||
|
4.30421368,
|
||||||
|
8.8275313, # 2nd highest value; idx. 17
|
||||||
|
5.44029958, # 5th highest value; idx. 18
|
||||||
|
-4.4735794,
|
||||||
|
7.38579536, # 3rd highest value; idx. 20
|
||||||
|
-2.91051663,
|
||||||
|
2.61946077,
|
||||||
|
-2.5674762,
|
||||||
|
-9.48959302,
|
||||||
|
-4.02922645,
|
||||||
|
-1.35416918,
|
||||||
|
9.67702323, # 1st highest value; idx. 27
|
||||||
|
-5.89478553,
|
||||||
|
1.85370467,
|
||||||
|
], # cumulative prob of 5 highest values <= 0.6
|
||||||
|
],
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
non_inf_expected_idx = torch.tensor(
|
||||||
|
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
) # expected non filtered idx as noted above
|
||||||
|
|
||||||
|
non_inf_expected_output = torch.tensor(
|
||||||
|
[
|
||||||
|
8.2221,
|
||||||
|
7.3534,
|
||||||
|
8.4321,
|
||||||
|
7.4402,
|
||||||
|
9.3845,
|
||||||
|
6.2712,
|
||||||
|
8.8275,
|
||||||
|
5.4403,
|
||||||
|
7.3858,
|
||||||
|
9.6770,
|
||||||
|
], # expected non filtered values as noted above
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||||
|
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||||
|
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||||
|
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||||
|
|||||||
@@ -386,33 +386,33 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_lm_generate_distilgpt2(self):
|
def test_lm_generate_distilgpt2(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
input_ids = torch.Tensor([[464, 1893]]).long() # The president
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
464,
|
464,
|
||||||
3290,
|
1893,
|
||||||
318,
|
286,
|
||||||
13779,
|
262,
|
||||||
996,
|
1578,
|
||||||
339,
|
1829,
|
||||||
460,
|
11,
|
||||||
3360,
|
290,
|
||||||
655,
|
262,
|
||||||
2513,
|
1893,
|
||||||
|
286,
|
||||||
|
262,
|
||||||
|
1578,
|
||||||
|
7526,
|
||||||
|
11,
|
||||||
|
423,
|
||||||
|
587,
|
||||||
287,
|
287,
|
||||||
262,
|
262,
|
||||||
3952,
|
2635,
|
||||||
13,
|
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||||
632,
|
|
||||||
318,
|
|
||||||
407,
|
|
||||||
845,
|
|
||||||
3621,
|
|
||||||
284,
|
|
||||||
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
|
|
||||||
torch.manual_seed(0)
|
|
||||||
|
|
||||||
output_ids = model.generate(
|
output_ids = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
do_sample=False,
|
||||||
bos_token_id=self.special_tokens["bos_token_id"],
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
|
|
||||||
@@ -28,6 +29,8 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import tf_top_k_top_p_filtering
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
for gpu in gpus:
|
for gpu in gpus:
|
||||||
@@ -56,6 +59,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
|
all_generative_model_classes = ()
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
@@ -216,7 +220,7 @@ class TFModelTesterMixin:
|
|||||||
outputs_dict = model(inputs_dict)
|
outputs_dict = model(inputs_dict)
|
||||||
|
|
||||||
inputs_keywords = copy.deepcopy(inputs_dict)
|
inputs_keywords = copy.deepcopy(inputs_dict)
|
||||||
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None)
|
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None,)
|
||||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||||
|
|
||||||
output_dict = outputs_dict[0].numpy()
|
output_dict = outputs_dict[0].numpy()
|
||||||
@@ -299,7 +303,7 @@ class TFModelTesterMixin:
|
|||||||
self.assertEqual(model.config.output_hidden_states, True)
|
self.assertEqual(model.config.output_hidden_states, True)
|
||||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size]
|
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
@@ -316,7 +320,10 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
first, second = (
|
||||||
|
model(inputs_dict, training=False)[0],
|
||||||
|
model(inputs_dict, training=False)[0],
|
||||||
|
)
|
||||||
out_1 = first.numpy()
|
out_1 = first.numpy()
|
||||||
out_2 = second.numpy()
|
out_2 = second.numpy()
|
||||||
out_1 = out_1[~np.isnan(out_1)]
|
out_1 = out_1[~np.isnan(out_1)]
|
||||||
@@ -338,9 +345,9 @@ class TFModelTesterMixin:
|
|||||||
x = wte([input_ids, None, None, None], mode="embedding")
|
x = wte([input_ids, None, None, None], mode="embedding")
|
||||||
except Exception:
|
except Exception:
|
||||||
if hasattr(self.model_tester, "embedding_size"):
|
if hasattr(self.model_tester, "embedding_size"):
|
||||||
x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
|
x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32,)
|
||||||
else:
|
else:
|
||||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32,)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
@@ -366,6 +373,37 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
model(inputs_dict)
|
model(inputs_dict)
|
||||||
|
|
||||||
|
def test_lm_head_model_random_generate(self):
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
input_ids = inputs_dict.get(
|
||||||
|
"input_ids", None
|
||||||
|
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to be changed when t5 is fixed.
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# TODO (PVP): add beam search tests when beam search is implemented
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
if config.bos_token_id is None:
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
model.generate(max_length=5)
|
||||||
|
# batch_size = 1
|
||||||
|
self._check_generated_tokens(model.generate(input_ids))
|
||||||
|
else:
|
||||||
|
# batch_size = 1
|
||||||
|
self._check_generated_tokens(model.generate(max_length=5))
|
||||||
|
# batch_size = 1, num_beams > 1
|
||||||
|
|
||||||
|
# batch_size > 1, sample
|
||||||
|
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||||
|
# batch_size > 1, greedy
|
||||||
|
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
|
||||||
|
|
||||||
|
def _check_generated_tokens(self, output_ids):
|
||||||
|
for token_id in output_ids[0].numpy().tolist():
|
||||||
|
self.assertGreaterEqual(token_id, 0)
|
||||||
|
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||||
|
|
||||||
|
|
||||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
@@ -383,3 +421,98 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
|||||||
output = tf.constant(values, shape=shape, dtype=dtype if dtype is not None else tf.int32)
|
output = tf.constant(values, shape=shape, dtype=dtype if dtype is not None else tf.int32)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
# tests whether the top_k_top_p_filtering function behaves as expected
|
||||||
|
def test_top_k_top_p_filtering(self):
|
||||||
|
logits = tf.convert_to_tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
8.2220991, # 3rd highest value; idx. 0
|
||||||
|
-0.5620044,
|
||||||
|
5.23229752,
|
||||||
|
4.0386393,
|
||||||
|
-6.8798378,
|
||||||
|
-0.54785802,
|
||||||
|
-3.2012153,
|
||||||
|
2.92777176,
|
||||||
|
1.88171953,
|
||||||
|
7.35341276, # 5th highest value; idx. 9
|
||||||
|
8.43207833, # 2nd highest value; idx. 10
|
||||||
|
-9.85711836,
|
||||||
|
-5.96209236,
|
||||||
|
-1.13039161,
|
||||||
|
-7.1115294,
|
||||||
|
-0.8369633,
|
||||||
|
-5.3186408,
|
||||||
|
7.06427407,
|
||||||
|
0.81369344,
|
||||||
|
-0.82023817,
|
||||||
|
-5.9179796,
|
||||||
|
0.58813443,
|
||||||
|
-6.99778438,
|
||||||
|
4.71551189,
|
||||||
|
-0.18771637,
|
||||||
|
7.44020759, # 4th highest value; idx. 25
|
||||||
|
9.38450987, # 1st highest value; idx. 26
|
||||||
|
2.12662941,
|
||||||
|
-9.32562038,
|
||||||
|
2.35652522,
|
||||||
|
], # cumulative prob of 5 highest values <= 0.6
|
||||||
|
[
|
||||||
|
0.58425518,
|
||||||
|
4.53139238,
|
||||||
|
-5.57510464,
|
||||||
|
-6.28030699,
|
||||||
|
-7.19529503,
|
||||||
|
-4.02122551,
|
||||||
|
1.39337037,
|
||||||
|
-6.06707057,
|
||||||
|
1.59480517,
|
||||||
|
-9.643119,
|
||||||
|
0.03907799,
|
||||||
|
0.67231762,
|
||||||
|
-8.88206726,
|
||||||
|
6.27115922, # 4th highest value; idx. 13
|
||||||
|
2.28520723,
|
||||||
|
4.82767506,
|
||||||
|
4.30421368,
|
||||||
|
8.8275313, # 2nd highest value; idx. 17
|
||||||
|
5.44029958, # 5th highest value; idx. 18
|
||||||
|
-4.4735794,
|
||||||
|
7.38579536, # 3rd highest value; idx. 20
|
||||||
|
-2.91051663,
|
||||||
|
2.61946077,
|
||||||
|
-2.5674762,
|
||||||
|
-9.48959302,
|
||||||
|
-4.02922645,
|
||||||
|
-1.35416918,
|
||||||
|
9.67702323, # 1st highest value; idx. 27
|
||||||
|
-5.89478553,
|
||||||
|
1.85370467,
|
||||||
|
], # cumulative prob of 5 highest values <= 0.6
|
||||||
|
],
|
||||||
|
dtype=tf.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
non_inf_expected_idx = tf.convert_to_tensor(
|
||||||
|
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], dtype=tf.int32,
|
||||||
|
) # expected non filtered idx as noted above
|
||||||
|
|
||||||
|
non_inf_expected_output = tf.convert_to_tensor(
|
||||||
|
[8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
|
||||||
|
dtype=tf.float32,
|
||||||
|
) # expected non filtered values as noted above
|
||||||
|
|
||||||
|
output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||||
|
|
||||||
|
non_inf_output = output[output != -float("inf")]
|
||||||
|
non_inf_idx = tf.cast(
|
||||||
|
tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))), dtype=tf.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
|
||||||
|
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_tf_available():
|
|||||||
class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
|
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
|
||||||
|
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||||
|
|
||||||
class TFCTRLModelTester(object):
|
class TFCTRLModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ if is_tf_available():
|
|||||||
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||||
|
|
||||||
class TFGPT2ModelTester(object):
|
class TFGPT2ModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -89,6 +89,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.bos_token_id = vocab_size - 1
|
||||||
|
self.eos_token_id = vocab_size - 1
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -123,9 +125,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# hidden_dropout_prob=self.hidden_dropout_prob,
|
# hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
n_positions=self.max_position_embeddings,
|
n_positions=self.max_position_embeddings,
|
||||||
n_ctx=self.max_position_embeddings
|
n_ctx=self.max_position_embeddings,
|
||||||
# type_vocab_size=self.type_vocab_size,
|
# type_vocab_size=self.type_vocab_size,
|
||||||
# initializer_range=self.initializer_range
|
# initializer_range=self.initializer_range
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
eos_token_ids=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||||
@@ -144,7 +148,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
model = TFGPT2Model(config=config)
|
model = TFGPT2Model(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
sequence_output = model(inputs)[0]
|
sequence_output = model(inputs)[0]
|
||||||
|
|
||||||
inputs = [input_ids, None, input_mask] # None is the input for 'past'
|
inputs = [input_ids, None, input_mask] # None is the input for 'past'
|
||||||
@@ -156,18 +164,22 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
"sequence_output": sequence_output.numpy(),
|
"sequence_output": sequence_output.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
model = TFGPT2LMHeadModel(config=config)
|
model = TFGPT2LMHeadModel(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
prediction_scores = model(inputs)[0]
|
prediction_scores = model(inputs)[0]
|
||||||
result = {
|
result = {
|
||||||
"prediction_scores": prediction_scores.numpy(),
|
"prediction_scores": prediction_scores.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_gpt2_double_head(
|
def create_and_check_gpt2_double_head(
|
||||||
@@ -188,7 +200,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
lm_logits, mc_logits = model(inputs)[:2]
|
lm_logits, mc_logits = model(inputs)[:2]
|
||||||
result = {"lm_logits": lm_logits.numpy(), "mc_logits": mc_logits.numpy()}
|
result = {"lm_logits": lm_logits.numpy(), "mc_logits": mc_logits.numpy()}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
|
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
@@ -207,7 +219,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
choice_labels,
|
choice_labels,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
|
|
||||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -234,3 +250,48 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_generation_special_tokens():
|
||||||
|
return {"bos_token_id": 50256, "eos_token_id": 50256}
|
||||||
|
|
||||||
|
|
||||||
|
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
special_tokens = prepare_generation_special_tokens()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_distilgpt2(self):
|
||||||
|
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
|
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
|
||||||
|
expected_output_ids = [
|
||||||
|
464,
|
||||||
|
1893,
|
||||||
|
286,
|
||||||
|
262,
|
||||||
|
1578,
|
||||||
|
1829,
|
||||||
|
11,
|
||||||
|
290,
|
||||||
|
262,
|
||||||
|
1893,
|
||||||
|
286,
|
||||||
|
262,
|
||||||
|
1578,
|
||||||
|
7526,
|
||||||
|
11,
|
||||||
|
423,
|
||||||
|
587,
|
||||||
|
287,
|
||||||
|
262,
|
||||||
|
2635,
|
||||||
|
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||||
|
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
|
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
|
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
|
||||||
)
|
)
|
||||||
|
all_generative_model_classes = (
|
||||||
|
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||||
|
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||||
|
|
||||||
class TFOpenAIGPTModelTester(object):
|
class TFOpenAIGPTModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ if is_tf_available():
|
|||||||
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
||||||
|
all_generative_model_classes = () if is_tf_available() else ()
|
||||||
|
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
@@ -62,6 +64,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
scope=None,
|
scope=None,
|
||||||
seed=1,
|
seed=1,
|
||||||
|
eos_token_id=0,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -82,6 +85,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -103,6 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
d_inner=self.d_inner,
|
d_inner=self.d_inner,
|
||||||
div_val=self.div_val,
|
div_val=self.div_val,
|
||||||
n_layer=self.num_hidden_layers,
|
n_layer=self.num_hidden_layers,
|
||||||
|
eos_token_ids=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
all_generative_model_classes = (
|
||||||
|
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||||
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
|
|
||||||
class TFXLMModelTester(object):
|
class TFXLMModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -75,6 +78,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
summary_type="last",
|
summary_type="last",
|
||||||
use_proj=True,
|
use_proj=True,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
bos_token_id=0,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -105,6 +109,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -145,6 +150,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
summary_type=self.summary_type,
|
summary_type=self.summary_type,
|
||||||
use_proj=self.use_proj,
|
use_proj=self.use_proj,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
all_generative_model_classes = (
|
||||||
|
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||||
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
class TFXLNetModelTester(object):
|
class TFXLNetModelTester(object):
|
||||||
@@ -77,6 +80,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
initializer_range=0.05,
|
initializer_range=0.05,
|
||||||
seed=1,
|
seed=1,
|
||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=5,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -100,6 +106,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.type_vocab_size = type_vocab_size
|
self.type_vocab_size = type_vocab_size
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -139,6 +148,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
bi_data=self.bi_data,
|
bi_data=self.bi_data,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
num_labels=self.type_sequence_label_size,
|
num_labels=self.type_sequence_label_size,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user