Skip non-working AQLM inference tests until the fix is merged (#34865)

This commit is contained in:
Mohamed Mekkouri
2024-11-26 11:09:30 +01:00
committed by GitHub
parent 73b4ab1085
commit 0e805e6d1e

View File

@@ -17,6 +17,7 @@ import gc
import importlib import importlib
import tempfile import tempfile
import unittest import unittest
from unittest import skip
from packaging import version from packaging import version
@@ -142,6 +143,9 @@ class AqlmTest(unittest.TestCase):
self.assertEqual(nb_linears - 1, nb_aqlm_linear) self.assertEqual(nb_linears - 1, nb_aqlm_linear)
@skip(
"inference doesn't work with quantized aqlm models using torch.Any type with recent torch versions. Waiting for the fix from AQLM side"
)
def test_quantized_model(self): def test_quantized_model(self):
""" """
Simple test that checks if the quantized model is working properly Simple test that checks if the quantized model is working properly
@@ -158,6 +162,9 @@ class AqlmTest(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
@skip(
"inference doesn't work with quantized aqlm models using torch.Any type with recent torch versions. Waiting for the fix from AQLM side"
)
def test_save_pretrained(self): def test_save_pretrained(self):
""" """
Simple test that checks if the quantized model is working properly after being saved and loaded Simple test that checks if the quantized model is working properly after being saved and loaded
@@ -171,6 +178,9 @@ class AqlmTest(unittest.TestCase):
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@skip(
"inference doesn't work with quantized aqlm models using torch.Any type with recent torch versions. Waiting for the fix from AQLM side"
)
@require_torch_multi_gpu @require_torch_multi_gpu
def test_quantized_model_multi_gpu(self): def test_quantized_model_multi_gpu(self):
""" """