Generate: Fix flaky indexing error in test_constrained_beam_search_generate_dict_output (#21561)
This commit is contained in:
@@ -1359,13 +1359,8 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
if not input_ids.dtype == torch.float32:
|
||||
min_id = torch.min(input_ids) + 3
|
||||
max_id = torch.max(input_ids)
|
||||
else:
|
||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
min_id = 3
|
||||
max_id = model.config.vocab_size
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
|
||||
Reference in New Issue
Block a user