Fixing quantization tests (#37650)
* fix * style * add capability check
This commit is contained in:
@@ -110,7 +110,7 @@ class AwqTest(unittest.TestCase):
|
||||
input_text = "Hello my name is"
|
||||
|
||||
EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
|
||||
EXPECTED_OUTPUT_BF16 = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a"
|
||||
EXPECTED_OUTPUT_BF16 = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
|
||||
|
||||
EXPECTED_OUTPUT_EXLLAMA = [
|
||||
"Hello my name is Katie and I am a 20 year old student from the UK. I am currently studying for a degree in English Literature and History at the University of York. I am a very out",
|
||||
@@ -299,7 +299,7 @@ class AwqFusedTest(unittest.TestCase):
|
||||
"You end up exactly where you started. Where are you?"
|
||||
)
|
||||
|
||||
EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
|
||||
EXPECTED_GENERATION = prompt + "\n\nYou're at the center of a square."
|
||||
EXPECTED_GENERATION_CUSTOM_MODEL = "Hello,\n\nI have a problem with my 20"
|
||||
EXPECTED_GENERATION_MIXTRAL = prompt + " You're on the North Pole.\n\nThe"
|
||||
|
||||
@@ -355,6 +355,10 @@ class AwqFusedTest(unittest.TestCase):
|
||||
# Checks if the modules_to_not_convert (here gate layer) is a Linear
|
||||
self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear))
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
|
||||
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
|
||||
)
|
||||
def test_generation_fused(self):
|
||||
"""
|
||||
Test generation quality for fused models - single batch case
|
||||
@@ -378,6 +382,10 @@ class AwqFusedTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
|
||||
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
|
||||
)
|
||||
def test_generation_fused_batched(self):
|
||||
"""
|
||||
Test generation quality for fused models - multi batch case
|
||||
@@ -426,6 +434,10 @@ class AwqFusedTest(unittest.TestCase):
|
||||
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
|
||||
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
|
||||
)
|
||||
def test_generation_custom_model(self):
|
||||
"""
|
||||
Test generation quality for fused models using custom fused map.
|
||||
|
||||
Reference in New Issue
Block a user