Generate: Fix flaky indexing error in test_constrained_beam_search_generate_dict_output (#21561)

This commit is contained in:
Joao Gante
2023-02-13 15:12:07 +00:00
committed by GitHub
parent 93ed89bf40
commit 24273268b7
2 changed files with 2 additions and 11 deletions

View File

@@ -1359,13 +1359,8 @@ class GenerationTesterMixin:
)
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
min_id = 3
max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),