TF generate refactor - XLA sample (#16713)
This commit is contained in:
@@ -346,6 +346,8 @@ class TFGenerationMixin:
|
|||||||
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
|
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
seed_generator = tf.random.Generator.from_non_deterministic_state()
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
|
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
|
||||||
@@ -585,6 +587,7 @@ class TFGenerationMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
seed=model_kwargs.pop("seed", None),
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_start_token_id=None,
|
decoder_start_token_id=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
|
seed=None,
|
||||||
output_scores=None,
|
output_scores=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
|
|||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||||
speed up decoding.
|
speed up decoding.
|
||||||
|
seed (`List[int]`, *optional*):
|
||||||
|
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
||||||
|
`seed` argument from stateless functions in `tf.random`.
|
||||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more details.
|
returned tensors for more details.
|
||||||
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
seed=seed,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Tuple[tf.Tensor, Dict[str, Any]]:
|
) -> Tuple[tf.Tensor, Dict[str, Any]]:
|
||||||
expanded_return_idx = tf.reshape(
|
expanded_return_idx = tf.reshape(
|
||||||
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
|
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
|
||||||
)
|
)
|
||||||
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
|
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
|
||||||
|
|
||||||
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
|
|||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
|
seed: Optional[Tuple[int, int]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
|
|||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
eos_token_id (`int`, *optional*):
|
eos_token_id (`int`, *optional*):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token.
|
||||||
|
seed (`List[int]`, *optional*):
|
||||||
|
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
||||||
|
`seed` argument from stateless functions in `tf.random`.
|
||||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more details.
|
returned tensors for more details.
|
||||||
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
|
|||||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
# init values
|
# 1. init greedy_search values
|
||||||
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||||
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
|
||||||
|
|
||||||
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
|
|||||||
return_dict_in_generate = (
|
return_dict_in_generate = (
|
||||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||||
)
|
)
|
||||||
|
use_xla = not tf.executing_eagerly()
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# 2. init `attentions`, `hidden_states`, and `scores` tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = [] if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
# define bsz, seq_length
|
||||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
batch_size, cur_len = input_ids.shape
|
||||||
encoder_hidden_states = (
|
|
||||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
# initialize `generated`, `finished_sequences`
|
||||||
|
generated = tf.TensorArray(
|
||||||
|
element_shape=(batch_size,),
|
||||||
|
dtype=tf.int32,
|
||||||
|
dynamic_size=False,
|
||||||
|
size=max_length,
|
||||||
|
clear_after_read=False,
|
||||||
)
|
)
|
||||||
|
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
||||||
|
|
||||||
# keep track of which sequences are already finished
|
# write prompt to generated
|
||||||
unfinished_sequences = tf.ones_like(input_ids[:, 0])
|
for i in range(cur_len):
|
||||||
cur_len = input_ids.shape[-1]
|
generated = generated.write(i, input_ids[:, i])
|
||||||
|
|
||||||
while cur_len < max_length:
|
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
||||||
# prepare model inputs
|
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
return ~tf.reduce_all(finished_sequences)
|
||||||
|
|
||||||
# forward pass to get next token
|
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
|
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
|
||||||
|
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
|
||||||
|
# forward pass to get next token logits
|
||||||
outputs = self(
|
outputs = self(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
|
next_token_logits = outputs.logits[:, -1]
|
||||||
next_token_logits = outputs.logits[:, -1, :]
|
|
||||||
|
|
||||||
# pre-process distribution
|
|
||||||
next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
|
||||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store scores, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if not use_xla and return_dict_in_generate:
|
||||||
if output_scores:
|
if output_scores:
|
||||||
scores += (next_token_scores,)
|
scores.append(next_token_logits)
|
||||||
if output_attentions:
|
if output_attentions and self.config.is_encoder_decoder:
|
||||||
decoder_attentions += (
|
decoder_attentions.append(outputs.decoder_attentions)
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
elif output_attentions and not self.config.is_encoder_decoder:
|
||||||
)
|
decoder_attentions.append(outputs.attentions)
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
cross_attentions += (outputs.cross_attentions,)
|
cross_attentions.append(outputs.cross_attentions)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states and self.config.is_encoder_decoder:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states.append(outputs.decoder_hidden_states)
|
||||||
(outputs.decoder_hidden_states,)
|
elif output_hidden_states and self.config.is_encoder_decoder:
|
||||||
if self.config.is_encoder_decoder
|
decoder_hidden_states.append(outputs.hidden_states)
|
||||||
else (outputs.hidden_states,)
|
|
||||||
)
|
# pre-process distribution
|
||||||
|
# TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
|
||||||
|
# to be XLA compatible
|
||||||
|
input_ids = None
|
||||||
|
if not use_xla:
|
||||||
|
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||||
|
input_ids = tf.transpose(input_ids[:cur_len])
|
||||||
|
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
||||||
|
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
|
||||||
|
|
||||||
# sample
|
# sample
|
||||||
|
if seed is not None:
|
||||||
|
sample_seed = seed
|
||||||
|
else:
|
||||||
|
sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
|
||||||
next_tokens = tf.squeeze(
|
next_tokens = tf.squeeze(
|
||||||
tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
|
tf.random.stateless_categorical(
|
||||||
|
logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
|
||||||
|
),
|
||||||
|
axis=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
if pad_token_id is None:
|
if pad_token_id is None:
|
||||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
|
||||||
|
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
|
||||||
|
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
|
||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
# update `generated` and `cur_len`
|
||||||
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
|
generated = generated.write(cur_len, next_tokens)
|
||||||
|
next_tokens = next_tokens[:, None]
|
||||||
|
cur_len += 1
|
||||||
|
|
||||||
|
# update model_kwargs
|
||||||
|
if use_xla:
|
||||||
|
model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
|
||||||
|
else:
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
cur_len = cur_len + 1
|
# if we don't cache past key values we need the whole input
|
||||||
|
if model_kwargs.get("past", None) is None:
|
||||||
|
# let's throw out `past` since we don't want `None` tensors
|
||||||
|
model_kwargs.pop("past", None)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
|
||||||
if eos_token_id is not None:
|
next_tokens = tf.transpose(next_tokens[:cur_len])
|
||||||
eos_in_sents = next_tokens == eos_token_id
|
|
||||||
# if sentence is unfinished and the token to add is eos
|
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
|
||||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
|
||||||
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
|
# 5. run generation
|
||||||
|
# 1st generation step has to be run before to initialize `past`
|
||||||
|
generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
|
||||||
|
generated, finished_sequences, input_ids, cur_len, model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# unfinished_sequences is set to zero if eos in sentence
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
|
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
|
maximum_iterations = max_length - cur_len - 1
|
||||||
|
generated, _, _, cur_len, _ = tf.while_loop(
|
||||||
|
sample_cond_fn,
|
||||||
|
sample_body_fn,
|
||||||
|
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
|
||||||
|
maximum_iterations=maximum_iterations,
|
||||||
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# 6. prepare outputs
|
||||||
if tf.math.reduce_max(unfinished_sequences) == 0:
|
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
|
||||||
break
|
|
||||||
|
if not use_xla:
|
||||||
|
# cut for backward compatibility
|
||||||
|
output_ids = output_ids[:, :cur_len]
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
|
# if model is an encoder-decoder, retrieve encoder attention weights
|
||||||
|
# and hidden states
|
||||||
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||||
|
encoder_hidden_states = (
|
||||||
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = tuple(scores) if scores is not None else None
|
||||||
|
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
|
||||||
|
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
|
||||||
|
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
|
||||||
|
|
||||||
return TFSampleEncoderDecoderOutput(
|
return TFSampleEncoderDecoderOutput(
|
||||||
sequences=input_ids,
|
sequences=output_ids,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return TFSampleDecoderOnlyOutput(
|
return TFSampleDecoderOnlyOutput(
|
||||||
sequences=input_ids,
|
sequences=output_ids,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return input_ids
|
return output_ids
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
self,
|
self,
|
||||||
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
|
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
|
||||||
@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
intermediary_running_sequences=None,
|
intermediary_running_sequences=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
|
|||||||
next_sequences,
|
next_sequences,
|
||||||
next_scores,
|
next_scores,
|
||||||
next_is_sent_finished,
|
next_is_sent_finished,
|
||||||
next_model_kwargs,
|
|
||||||
next_input_ids_length,
|
next_input_ids_length,
|
||||||
|
next_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. run generation
|
# 5. run generation
|
||||||
@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
) = beam_search_body_fn(
|
) = beam_search_body_fn(
|
||||||
cur_len,
|
cur_len,
|
||||||
running_sequences,
|
running_sequences,
|
||||||
@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
||||||
@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
):
|
):
|
||||||
maximum_iterations = max_length - cur_len
|
maximum_iterations = max_length - cur_len
|
||||||
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
|
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
|
||||||
@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
|
|||||||
sequences,
|
sequences,
|
||||||
scores,
|
scores,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
|
||||||
input_ids_length,
|
input_ids_length,
|
||||||
|
model_kwargs,
|
||||||
),
|
),
|
||||||
maximum_iterations=maximum_iterations,
|
maximum_iterations=maximum_iterations,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||||
@slow
|
|
||||||
def test_lm_generate_distilgpt2(self):
|
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
|
||||||
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
|
|
||||||
|
|
||||||
# The president of the United States, and the president of the United Kingdom, have been in the White
|
|
||||||
# fmt: off
|
|
||||||
expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, do_sample=False)
|
|
||||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_greedy_distilgpt2_batch_special(self):
|
def test_lm_generate_greedy_distilgpt2_batch_special(self):
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
"temperature": 1.5,
|
"temperature": 1.5,
|
||||||
"top_k": 500,
|
"top_k": 500,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
|
"seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||||
}
|
}
|
||||||
|
|
||||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||||
with tf.device(":/CPU:0"):
|
with tf.device(":/CPU:0"):
|
||||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
|
||||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
|
"Today is a beautiful day and we will make you feel very hot/terrific in all",
|
||||||
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
|
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
|
||||||
]
|
]
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_gpt2_xla(self):
|
def test_lm_generate_gpt2_xla_greedy(self):
|
||||||
"""This test gives the exact same results as the non-xla test above"""
|
"""This test gives the exact same results as the non-xla test above"""
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||||
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output_ids = xla_generate(input_ids, do_sample=False)
|
output_ids = xla_generate(input_ids, do_sample=False)
|
||||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_gpt2_xla_sample(self):
|
||||||
|
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
|
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
|
||||||
|
# fmt: on
|
||||||
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
|
|
||||||
|
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||||
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(expected_output_string, output_strings)
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_sample_xla_generate_simple(self):
|
||||||
|
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||||
|
|
||||||
|
sentence = "Translate English to German: Today is a beautiful day."
|
||||||
|
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||||
|
# XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
|
||||||
|
# divergences in generate -- especially with sampling.
|
||||||
|
expected_output_string = ["Heute ist ein schöner Tag."]
|
||||||
|
expected_output_string_xla = ["Heute ist ein schöne Tage."]
|
||||||
|
# However, notice that the first tokens are the same, for the same seed
|
||||||
|
assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
|
||||||
|
|
||||||
|
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||||
|
with tf.device(":/CPU:0"):
|
||||||
|
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||||
|
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
|
||||||
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|
||||||
|
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||||
|
with tf.device(":/CPU:0"):
|
||||||
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
|
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||||
|
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||||
|
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
||||||
|
self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||||
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
"top_k": 500,
|
"top_k": 500,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
|
"seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||||
}
|
}
|
||||||
|
|
||||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||||
with tf.device(":/CPU:0"):
|
with tf.device(":/CPU:0"):
|
||||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
|
||||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
|
expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]
|
||||||
|
|
||||||
self.assertListEqual(expected_output_string, output_strings)
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user