Config: unified logic to retrieve text config (#33219)

This commit is contained in:
Joao Gante
2024-09-04 12:03:30 +01:00
committed by GitHub
parent ebbe8d8014
commit d750b509fc
10 changed files with 91 additions and 88 deletions

View File

@@ -831,7 +831,7 @@ class GenerationTesterMixin:
# Sample constraints
min_id = 3
max_id = config.vocab_size
max_id = config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
@@ -889,7 +889,7 @@ class GenerationTesterMixin:
# Sample constraints
min_id = 3
max_id = model.config.vocab_size
max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
@@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
self.assertTrue(output.past_key_values is None)
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
vocab_size = config.get_text_config(decoder=True).vocab_size
expected_shape = (batch_size, vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def _check_logits(self, batch_size, scores, config):
vocab_size = config.get_text_config(decoder=True).vocab_size
self.assertIsInstance(scores, tuple)
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = config.vocab_size - scores[0].shape[-1]
vocab_diff = vocab_size - scores[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1