Generate: add test to check KV format (#23403)

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante
2023-05-16 19:28:19 +01:00
committed by GitHub
parent 9cf4a8b456
commit 918a06e25d
4 changed files with 84 additions and 1 deletions

View File

@@ -393,6 +393,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
@unittest.skip("Bloom has a non-standard KV cache format.")
def test_past_key_values_format(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: