Fix quality

This commit is contained in:
Sylvain Gugger
2022-02-09 12:06:59 -05:00
parent eed3186b79
commit b1ba03e082

View File

@@ -318,7 +318,7 @@ class ConstrainedBeamSearchTester:
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx)) beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score # -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty)) self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished # -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length)) self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))