Update composition flag usage (#36263)
* update composition flag usage * remove print * fix tests * actually fix * oh c'mon * now should be fixed right? * fix copies
This commit is contained in:
committed by
GitHub
parent
08e3217baf
commit
6f4058aee3
@@ -1755,9 +1755,7 @@ class GenerationTesterMixin:
|
||||
|
||||
text_config = model.config.get_text_config()
|
||||
head_dim = (
|
||||
text_config.head_dim
|
||||
if hasattr(text_config, "head_dim")
|
||||
else text_config.hidden_size // text_config.num_attention_heads
|
||||
getattr(text_config, "head_dim", None) or text_config.hidden_size // text_config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
text_config.num_attention_heads
|
||||
@@ -2008,9 +2006,8 @@ class GenerationTesterMixin:
|
||||
max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache
|
||||
text_config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
text_config.head_dim
|
||||
if hasattr(text_config, "head_dim")
|
||||
else text_config.hidden_size // text_config.num_attention_heads
|
||||
getattr(text_config, "head_dim", None)
|
||||
or text_config.hidden_size // text_config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
text_config.num_attention_heads
|
||||
|
||||
@@ -184,13 +184,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
# overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
|
||||
# TODO: avoid overwritten once there is a better fix for #36077
|
||||
def check_config_can_be_init_without_params():
|
||||
config = self.config_tester.config_class()
|
||||
self.config_tester.parent.assertIsNotNone(config)
|
||||
|
||||
self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
|
||||
@@ -163,7 +163,7 @@ class ConfigTester:
|
||||
self.parent.assertEqual(len(config.label2id), 3)
|
||||
|
||||
def check_config_can_be_init_without_params(self):
|
||||
if self.config_class.is_composition:
|
||||
if self.config_class.has_no_defaults_at_init:
|
||||
with self.parent.assertRaises(ValueError):
|
||||
config = self.config_class()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user