Config: unified logic to retrieve text config (#33219)
This commit is contained in:
@@ -831,7 +831,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = config.vocab_size
|
||||
max_id = config.get_text_config(decoder=True).vocab_size
|
||||
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
@@ -889,7 +889,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = model.config.vocab_size
|
||||
max_id = model.config.get_text_config(decoder=True).vocab_size
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
@@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(output.past_key_values is None)
|
||||
|
||||
def _check_scores(self, batch_size, scores, length, config):
|
||||
expected_shape = (batch_size, config.vocab_size)
|
||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||
expected_shape = (batch_size, vocab_size)
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertEqual(len(scores), length)
|
||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||
|
||||
def _check_logits(self, batch_size, scores, config):
|
||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
|
||||
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
|
||||
vocab_diff = config.vocab_size - scores[0].shape[-1]
|
||||
vocab_diff = vocab_size - scores[0].shape[-1]
|
||||
self.assertTrue(vocab_diff in [0, 1])
|
||||
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
|
||||
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
|
||||
Reference in New Issue
Block a user