Fix MinNewTokensLengthLogitsProcessor when used with a list of eos tokens (#21959)

* Fix MinNewTokensLengthLogitsProcessor when used with a list of eos tokens

* fix docs

* Empty commit

* formatting
This commit is contained in:
Elad Segal
2023-03-07 13:59:22 +02:00
committed by GitHub
parent 4063fd9cba
commit eec46b4f75
2 changed files with 31 additions and 12 deletions

View File

@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import List, Union
from parameterized import parameterized
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
@@ -76,10 +78,10 @@ class LogitsProcessorTest(unittest.TestCase):
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_new_min_length_dist_processor(self):
@parameterized.expand([(0,), ([0, 18],)])
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]):
vocab_size = 20
batch_size = 4
eos_token_id = 0
# check that first input is skipped (min new length applying)
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
@@ -87,9 +89,15 @@ class LogitsProcessorTest(unittest.TestCase):
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
)
expected_eos_scores_before_min_length = batch_size * [-float("inf")]
if isinstance(eos_token_id, list):
expected_eos_scores_before_min_length *= len(eos_token_id)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)
# check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped
self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
@@ -98,19 +106,25 @@ class LogitsProcessorTest(unittest.TestCase):
input_ids = ids_tensor((batch_size, 2), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)
# check that min new length is applied at length 6 (because it has only 1 new token)
input_ids = ids_tensor((batch_size, 6), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)
# check that min new length is applied at length 7 (because it has only 2 new tokens)
input_ids = ids_tensor((batch_size, 7), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
self.assertListEqual(
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
)
# check that min new length is not applied anymore at length 8
input_ids = ids_tensor((batch_size, 8), vocab_size=20)