TF generate refactor - Sample (#15793)

* Add TF logits wrappers 

* Add sample method

* add tests for TF logit wrappers

* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-03-02 16:13:54 +00:00
committed by GitHub
parent 96ae92be8c
commit baab5e7cdf
13 changed files with 652 additions and 332 deletions

View File

@@ -497,9 +497,11 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)