Fix failing test_batch_generation for bloom (#25718)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -449,9 +449,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
|
|
||||||
input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
|
input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
|
||||||
|
|
||||||
input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
|
inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
|
||||||
input_ids = input_ids["input_ids"].to(torch_device)
|
input_ids = inputs["input_ids"].to(torch_device)
|
||||||
attention_mask = input_ids["attention_mask"]
|
attention_mask = inputs["attention_mask"]
|
||||||
greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
|
greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|||||||
Reference in New Issue
Block a user