Generate: Fix flaky indexing error in test_constrained_beam_search_generate_dict_output (#21561)
This commit is contained in:
@@ -1359,13 +1359,8 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
if not input_ids.dtype == torch.float32:
|
||||
min_id = torch.min(input_ids) + 3
|
||||
max_id = torch.max(input_ids)
|
||||
else:
|
||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
min_id = 3
|
||||
max_id = model.config.vocab_size
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
|
||||
@@ -632,10 +632,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skip while we investigate while it's failing.")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
|
||||
Reference in New Issue
Block a user