Fix failing test_batch_generation for bloom (#25718)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-08-24 11:15:29 +02:00
committed by GitHub
parent f01459c75d
commit 8fff61b9db

View File

@@ -449,9 +449,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
input_ids = input_ids["input_ids"].to(torch_device)
attention_mask = input_ids["attention_mask"]
inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
attention_mask = inputs["attention_mask"]
greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
self.assertEqual(