TF generate refactor - Sample (#15793)

* Add TF logits wrappers 

* Add sample method

* add tests for TF logit wrappers

* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-03-02 16:13:54 +00:00
committed by GitHub
parent 96ae92be8c
commit baab5e7cdf
13 changed files with 652 additions and 332 deletions

View File

@@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_lenght_dist_processor(self):
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0