TF generate refactor - XLA sample (#16713)
This commit is contained in:
@@ -346,6 +346,8 @@ class TFGenerationMixin:
|
||||
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
|
||||
"""
|
||||
|
||||
seed_generator = tf.random.Generator.from_non_deterministic_state()
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
|
||||
@@ -585,6 +587,7 @@ class TFGenerationMixin:
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
use_cache=use_cache,
|
||||
seed=model_kwargs.pop("seed", None),
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
|
||||
attention_mask=None,
|
||||
decoder_start_token_id=None,
|
||||
use_cache=None,
|
||||
seed=None,
|
||||
output_scores=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
seed (`List[int]`, *optional*):
|
||||
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
||||
`seed` argument from stateless functions in `tf.random`.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more details.
|
||||
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
seed=seed,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**model_kwargs,
|
||||
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
|
||||
**model_kwargs,
|
||||
) -> Tuple[tf.Tensor, Dict[str, Any]]:
|
||||
expanded_return_idx = tf.reshape(
|
||||
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
|
||||
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
|
||||
)
|
||||
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
|
||||
|
||||
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
seed: Optional[Tuple[int, int]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
seed (`List[int]`, *optional*):
|
||||
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
|
||||
`seed` argument from stateless functions in `tf.random`.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more details.
|
||||
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
```"""
|
||||
|
||||
# init values
|
||||
# 1. init greedy_search values
|
||||
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
|
||||
|
||||
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
|
||||
return_dict_in_generate = (
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
use_xla = not tf.executing_eagerly()
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
# 2. init `attentions`, `hidden_states`, and `scores` tuples
|
||||
scores = [] if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
# 3. init tensors to use for "xla-compileable" generate function
|
||||
# define bsz, seq_length
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
# initialize `generated`, `finished_sequences`
|
||||
generated = tf.TensorArray(
|
||||
element_shape=(batch_size,),
|
||||
dtype=tf.int32,
|
||||
dynamic_size=False,
|
||||
size=max_length,
|
||||
clear_after_read=False,
|
||||
)
|
||||
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = tf.ones_like(input_ids[:, 0])
|
||||
cur_len = input_ids.shape[-1]
|
||||
# write prompt to generated
|
||||
for i in range(cur_len):
|
||||
generated = generated.write(i, input_ids[:, i])
|
||||
|
||||
while cur_len < max_length:
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
||||
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||
return ~tf.reduce_all(finished_sequences)
|
||||
|
||||
# forward pass to get next token
|
||||
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
|
||||
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
|
||||
# forward pass to get next token logits
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
next_token_logits = outputs.logits[:, -1]
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if not use_xla and return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
scores.append(next_token_logits)
|
||||
if output_attentions and self.config.is_encoder_decoder:
|
||||
decoder_attentions.append(outputs.decoder_attentions)
|
||||
elif output_attentions and not self.config.is_encoder_decoder:
|
||||
decoder_attentions.append(outputs.attentions)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
cross_attentions.append(outputs.cross_attentions)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
if output_hidden_states and self.config.is_encoder_decoder:
|
||||
decoder_hidden_states.append(outputs.decoder_hidden_states)
|
||||
elif output_hidden_states and self.config.is_encoder_decoder:
|
||||
decoder_hidden_states.append(outputs.hidden_states)
|
||||
|
||||
# pre-process distribution
|
||||
# TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
|
||||
# to be XLA compatible
|
||||
input_ids = None
|
||||
if not use_xla:
|
||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||
input_ids = tf.transpose(input_ids[:cur_len])
|
||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
||||
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
|
||||
|
||||
# sample
|
||||
if seed is not None:
|
||||
sample_seed = seed
|
||||
else:
|
||||
sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
|
||||
next_tokens = tf.squeeze(
|
||||
tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
|
||||
tf.random.stateless_categorical(
|
||||
logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
|
||||
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
|
||||
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
|
||||
# update `generated` and `cur_len`
|
||||
generated = generated.write(cur_len, next_tokens)
|
||||
next_tokens = next_tokens[:, None]
|
||||
cur_len += 1
|
||||
|
||||
# update model_kwargs
|
||||
if use_xla:
|
||||
model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
|
||||
else:
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
cur_len = cur_len + 1
|
||||
# if we don't cache past key values we need the whole input
|
||||
if model_kwargs.get("past", None) is None:
|
||||
# let's throw out `past` since we don't want `None` tensors
|
||||
model_kwargs.pop("past", None)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
eos_in_sents = next_tokens == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos
|
||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
||||
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
|
||||
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
|
||||
next_tokens = tf.transpose(next_tokens[:cur_len])
|
||||
|
||||
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
|
||||
|
||||
# 5. run generation
|
||||
# 1st generation step has to be run before to initialize `past`
|
||||
generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
|
||||
generated, finished_sequences, input_ids, cur_len, model_kwargs
|
||||
)
|
||||
|
||||
# unfinished_sequences is set to zero if eos in sentence
|
||||
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||
# only in case 1st generation step does NOT yield EOS token though
|
||||
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||
maximum_iterations = max_length - cur_len - 1
|
||||
generated, _, _, cur_len, _ = tf.while_loop(
|
||||
sample_cond_fn,
|
||||
sample_body_fn,
|
||||
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if tf.math.reduce_max(unfinished_sequences) == 0:
|
||||
break
|
||||
# 6. prepare outputs
|
||||
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
|
||||
|
||||
if not use_xla:
|
||||
# cut for backward compatibility
|
||||
output_ids = output_ids[:, :cur_len]
|
||||
|
||||
if return_dict_in_generate:
|
||||
if self.config.is_encoder_decoder:
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights
|
||||
# and hidden states
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
scores = tuple(scores) if scores is not None else None
|
||||
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
|
||||
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
|
||||
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
|
||||
|
||||
return TFSampleEncoderDecoderOutput(
|
||||
sequences=input_ids,
|
||||
sequences=output_ids,
|
||||
scores=scores,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
|
||||
)
|
||||
else:
|
||||
return TFSampleDecoderOnlyOutput(
|
||||
sequences=input_ids,
|
||||
sequences=output_ids,
|
||||
scores=scores,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
return input_ids
|
||||
return output_ids
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
):
|
||||
"""
|
||||
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
|
||||
@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
intermediary_running_sequences=None,
|
||||
):
|
||||
"""
|
||||
@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
|
||||
next_sequences,
|
||||
next_scores,
|
||||
next_is_sent_finished,
|
||||
next_model_kwargs,
|
||||
next_input_ids_length,
|
||||
next_model_kwargs,
|
||||
)
|
||||
|
||||
# 5. run generation
|
||||
@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
) = beam_search_body_fn(
|
||||
cur_len,
|
||||
running_sequences,
|
||||
@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
)
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
||||
@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
):
|
||||
maximum_iterations = max_length - cur_len
|
||||
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
|
||||
@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
|
||||
sequences,
|
||||
scores,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
input_ids_length,
|
||||
model_kwargs,
|
||||
),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
|
||||
@require_tf
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
|
||||
|
||||
# The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
# fmt: off
|
||||
expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
|
||||
# fmt: on
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_greedy_distilgpt2_batch_special(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"temperature": 1.5,
|
||||
"top_k": 500,
|
||||
"top_p": 0.9,
|
||||
"seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||
}
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = [
|
||||
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
|
||||
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
|
||||
"Today is a beautiful day and we will make you feel very hot/terrific in all",
|
||||
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
|
||||
]
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_gpt2_xla(self):
|
||||
def test_lm_generate_gpt2_xla_greedy(self):
|
||||
"""This test gives the exact same results as the non-xla test above"""
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
output_ids = xla_generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_gpt2_xla_sample(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||
|
||||
# fmt: off
|
||||
expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
|
||||
# fmt: on
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@slow
|
||||
def test_sample_xla_generate_simple(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
sentence = "Translate English to German: Today is a beautiful day."
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
# XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
|
||||
# divergences in generate -- especially with sampling.
|
||||
expected_output_string = ["Heute ist ein schöner Tag."]
|
||||
expected_output_string_xla = ["Heute ist ein schöne Tage."]
|
||||
# However, notice that the first tokens are the same, for the same seed
|
||||
assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_sample_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
"temperature": 0.8,
|
||||
"top_k": 500,
|
||||
"top_p": 0.9,
|
||||
"seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||
}
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
|
||||
expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]
|
||||
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user