Generate: Fix flaky indexing error in test_constrained_beam_search_generate_dict_output (#21561)
This commit is contained in:
@@ -1359,13 +1359,8 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
if not input_ids.dtype == torch.float32:
|
|
||||||
min_id = torch.min(input_ids) + 3
|
|
||||||
max_id = torch.max(input_ids)
|
|
||||||
else:
|
|
||||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
|
||||||
min_id = 3
|
min_id = 3
|
||||||
max_id = 100
|
max_id = model.config.vocab_size
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||||
constraints = [
|
constraints = [
|
||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
|
|||||||
@@ -632,10 +632,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("Skip while we investigate while it's failing.")
|
|
||||||
def test_constrained_beam_search_generate_dict_output(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_encoder_outputs(
|
def _get_encoder_outputs(
|
||||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||||
|
|||||||
Reference in New Issue
Block a user