TF generate refactor - Sample (#15793)
* Add TF logits wrappers * Add sample method * add tests for TF logit wrappers * TF generate sample tests now run on CPU Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"top_k": 500,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = [
|
||||
|
||||
Reference in New Issue
Block a user