Expectation fixes and added AMD expectations (#38729)
This commit is contained in:
@@ -3237,7 +3237,9 @@ def cleanup(device: str, gc_collect=False):
|
|||||||
|
|
||||||
|
|
||||||
# Type definition of key used in `Expectations` class.
|
# Type definition of key used in `Expectations` class.
|
||||||
DeviceProperties = tuple[Union[str, None], Union[int, None]]
|
DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]]
|
||||||
|
# Helper type. Makes creating instances of `Expectations` smoother.
|
||||||
|
PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]]
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
@@ -3248,11 +3250,11 @@ def get_device_properties() -> DeviceProperties:
|
|||||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
major, _ = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
if IS_ROCM_SYSTEM:
|
if IS_ROCM_SYSTEM:
|
||||||
return ("rocm", major)
|
return ("rocm", major, minor)
|
||||||
else:
|
else:
|
||||||
return ("cuda", major)
|
return ("cuda", major, minor)
|
||||||
elif IS_XPU_SYSTEM:
|
elif IS_XPU_SYSTEM:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -3260,58 +3262,81 @@ def get_device_properties() -> DeviceProperties:
|
|||||||
arch = torch.xpu.get_device_capability()["architecture"]
|
arch = torch.xpu.get_device_capability()["architecture"]
|
||||||
gen_mask = 0x000000FF00000000
|
gen_mask = 0x000000FF00000000
|
||||||
gen = (arch & gen_mask) >> 32
|
gen = (arch & gen_mask) >> 32
|
||||||
return ("xpu", gen)
|
return ("xpu", gen, None)
|
||||||
else:
|
else:
|
||||||
return (torch_device, None)
|
return (torch_device, None, None)
|
||||||
|
|
||||||
|
|
||||||
class Expectations(UserDict[DeviceProperties, Any]):
|
def unpack_device_properties(
|
||||||
|
properties: Optional[PackedDeviceProperties] = None,
|
||||||
|
) -> DeviceProperties:
|
||||||
|
"""
|
||||||
|
Unpack a `PackedDeviceProperties` tuple into consistently formatted `DeviceProperties` tuple. If properties is None, it is fetched.
|
||||||
|
"""
|
||||||
|
if properties is None:
|
||||||
|
return get_device_properties()
|
||||||
|
device_type, major_minor = properties
|
||||||
|
if major_minor is None:
|
||||||
|
major, minor = None, None
|
||||||
|
elif isinstance(major_minor, int):
|
||||||
|
major, minor = major_minor, None
|
||||||
|
else:
|
||||||
|
major, minor = major_minor
|
||||||
|
return device_type, major, minor
|
||||||
|
|
||||||
|
|
||||||
|
class Expectations(UserDict[PackedDeviceProperties, Any]):
|
||||||
def get_expectation(self) -> Any:
|
def get_expectation(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Find best matching expectation based on environment device properties.
|
Find best matching expectation based on environment device properties.
|
||||||
"""
|
"""
|
||||||
return self.find_expectation(get_device_properties())
|
return self.find_expectation(get_device_properties())
|
||||||
|
|
||||||
@staticmethod
|
def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
|
||||||
def is_default(key: DeviceProperties) -> bool:
|
return [(unpack_device_properties(k), v) for k, v in self.data.items()]
|
||||||
return all(p is None for p in key)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def score(key: DeviceProperties, other: DeviceProperties) -> int:
|
def is_default(properties: DeviceProperties) -> bool:
|
||||||
|
return all(p is None for p in properties)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def score(properties: DeviceProperties, other: DeviceProperties) -> float:
|
||||||
"""
|
"""
|
||||||
Returns score indicating how similar two instances of the `Properties` tuple are.
|
Returns score indicating how similar two instances of the `Properties` tuple are.
|
||||||
Points are calculated using bits, but documented as int.
|
|
||||||
Rules are as follows:
|
Rules are as follows:
|
||||||
* Matching `type` gives 8 points.
|
* Matching `type` adds one point, semi-matching `type` adds half a point (e.g. cuda and rocm).
|
||||||
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
|
* If types match, matching `major` adds another point, and then matching `minor` adds another.
|
||||||
* Matching `major` (compute capability major version) gives 2 points.
|
* Default expectation (if present) is worth 0.1 point to distinguish it from a straight-up zero.
|
||||||
* Default expectation (if present) gives 1 points.
|
|
||||||
"""
|
"""
|
||||||
(device_type, major) = key
|
device_type, major, minor = properties
|
||||||
(other_device_type, other_major) = other
|
other_device_type, other_major, other_minor = other
|
||||||
|
|
||||||
score = 0b0
|
score = 0
|
||||||
if device_type == other_device_type:
|
# Matching device type, maybe major and minor
|
||||||
score |= 0b1000
|
if device_type is not None and device_type == other_device_type:
|
||||||
|
score += 1
|
||||||
|
if major is not None and major == other_major:
|
||||||
|
score += 1
|
||||||
|
if minor is not None and minor == other_minor:
|
||||||
|
score += 1
|
||||||
|
# Semi-matching device type
|
||||||
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
|
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
|
||||||
score |= 0b100
|
score = 0.5
|
||||||
|
|
||||||
if major == other_major and other_major is not None:
|
|
||||||
score |= 0b10
|
|
||||||
|
|
||||||
|
# Default expectation
|
||||||
if Expectations.is_default(other):
|
if Expectations.is_default(other):
|
||||||
score |= 0b1
|
score = 0.1
|
||||||
|
|
||||||
return int(score)
|
return score
|
||||||
|
|
||||||
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
|
def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
|
||||||
"""
|
"""
|
||||||
Find best matching expectation based on provided device properties.
|
Find best matching expectation based on provided device properties.
|
||||||
"""
|
"""
|
||||||
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
|
(result_key, result) = max(self.unpacked(), key=lambda x: Expectations.score(properties, x[0]))
|
||||||
|
|
||||||
if Expectations.score(key, result_key) == 0:
|
if Expectations.score(properties, result_key) == 0:
|
||||||
raise ValueError(f"No matching expectation found for {key}")
|
raise ValueError(f"No matching expectation found for {properties}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -347,7 +347,8 @@ class AyaVisionIntegrationTest(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_model(cls):
|
def get_model(cls):
|
||||||
# Use 4-bit on T4
|
# Use 4-bit on T4
|
||||||
load_in_4bit = get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8
|
device_type, major, _ = get_device_properties()
|
||||||
|
load_in_4bit = (device_type == "cuda") and (major < 8)
|
||||||
torch_dtype = None if load_in_4bit else torch.float16
|
torch_dtype = None if load_in_4bit else torch.float16
|
||||||
|
|
||||||
if cls.model is None:
|
if cls.model is None:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
DeviceProperties,
|
||||||
Expectations,
|
Expectations,
|
||||||
get_device_properties,
|
get_device_properties,
|
||||||
require_deterministic_for_xpu,
|
require_deterministic_for_xpu,
|
||||||
@@ -594,7 +595,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = None
|
tokenizer = None
|
||||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||||
# Depending on the hardware we get different logits / generations
|
# Depending on the hardware we get different logits / generations
|
||||||
device_properties = None
|
device_properties: DeviceProperties = (None, None, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -637,7 +638,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_sentence, expected)
|
self.assertEqual(output_sentence, expected)
|
||||||
|
|
||||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
if self.device_properties == ("cuda", 8):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
|
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
|
||||||
|
|
||||||
@@ -690,7 +691,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])
|
self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])
|
||||||
|
|
||||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
if self.device_properties == ("cuda", 8):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from packaging import version
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
DeviceProperties,
|
||||||
Expectations,
|
Expectations,
|
||||||
cleanup,
|
cleanup,
|
||||||
get_device_properties,
|
get_device_properties,
|
||||||
@@ -108,7 +109,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
input_text = ["Hello I am doing", "Hi today"]
|
input_text = ["Hello I am doing", "Hi today"]
|
||||||
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
||||||
# Depending on the hardware we get different logits / generations
|
# Depending on the hardware we get different logits / generations
|
||||||
device_properties = None
|
device_properties: DeviceProperties = (None, None, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -241,7 +242,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_7b_fp16(self):
|
def test_model_7b_fp16(self):
|
||||||
if self.device_properties == ("cuda", 7):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||||
|
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
@@ -262,7 +263,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_7b_bf16(self):
|
def test_model_7b_bf16(self):
|
||||||
if self.device_properties == ("cuda", 7):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||||
|
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
@@ -293,7 +294,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_7b_fp16_static_cache(self):
|
def test_model_7b_fp16_static_cache(self):
|
||||||
if self.device_properties == ("cuda", 7):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||||
|
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import pytest
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GlmConfig, is_torch_available
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GlmConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_large_accelerator,
|
require_torch_large_accelerator,
|
||||||
@@ -118,10 +119,17 @@ class GlmIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
def test_model_9b_eager(self):
|
def test_model_9b_eager(self):
|
||||||
EXPECTED_TEXTS = [
|
expected_texts = Expectations({
|
||||||
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
|
("cuda", None): [
|
||||||
"Hi today I am going to show you how to make a simple and easy to make a DIY paper flower.",
|
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
|
||||||
]
|
"Hi today I am going to show you how to make a simple and easy to make a DIY paper flower.",
|
||||||
|
],
|
||||||
|
("rocm", (9, 5)) : [
|
||||||
|
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
|
||||||
|
"Hi today I am going to show you how to make a simple and easy to make a paper airplane. First",
|
||||||
|
]
|
||||||
|
}) # fmt: skip
|
||||||
|
EXPECTED_TEXTS = expected_texts.get_expectation()
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.model_id,
|
self.model_id,
|
||||||
|
|||||||
@@ -821,6 +821,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
expected_outputs = Expectations(
|
expected_outputs = Expectations(
|
||||||
{
|
{
|
||||||
("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
|
("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
|
||||||
|
("rocm", (9, 5)): "Today is a nice day and if you don't know anything about the state of play during your holiday",
|
||||||
("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
|
("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HeliumConfig, is_torch_available
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HeliumConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
@@ -83,9 +84,13 @@ class HeliumIntegrationTest(unittest.TestCase):
|
|||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_2b(self):
|
def test_model_2b(self):
|
||||||
model_id = "kyutai/helium-1-preview"
|
model_id = "kyutai/helium-1-preview"
|
||||||
EXPECTED_TEXTS = [
|
expected_texts = Expectations(
|
||||||
"Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"
|
{
|
||||||
]
|
("rocm", (9, 5)): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now, and I"],
|
||||||
|
("cuda", None): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"],
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_TEXTS = expected_texts.get_expectation()
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, revision="refs/pr/1").to(
|
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, revision="refs/pr/1").to(
|
||||||
torch_device
|
torch_device
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from transformers import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
cleanup,
|
cleanup,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -621,8 +622,14 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
generated_ids = model.generate(**inputs, max_new_tokens=10)
|
generated_ids = model.generate(**inputs, max_new_tokens=10)
|
||||||
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
|
expected_generated_texts = Expectations(
|
||||||
self.assertEqual(generated_texts[0], expected_generated_text)
|
{
|
||||||
|
("cuda", None): "In this image, we see the Statue of Liberty, the Hudson River,",
|
||||||
|
("rocm", (9, 5)): "In this image, we see the Statue of Liberty, the New York City",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
EXPECTED_GENERATED_TEXT = expected_generated_texts.get_expectation()
|
||||||
|
self.assertEqual(generated_texts[0], EXPECTED_GENERATED_TEXT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
|
|||||||
@@ -537,6 +537,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
("xpu", 3): "The man is performing a volley.",
|
("xpu", 3): "The man is performing a volley.",
|
||||||
("cuda", 7): "The man is performing a forehand shot.",
|
("cuda", 7): "The man is performing a forehand shot.",
|
||||||
|
("rocm", (9, 5)): "The man is performing a volley shot.",
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
expected_output = expected_outputs.get_expectation()
|
expected_output = expected_outputs.get_expectation()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import pytest
|
|||||||
|
|
||||||
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
DeviceProperties,
|
||||||
Expectations,
|
Expectations,
|
||||||
get_device_properties,
|
get_device_properties,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -557,7 +558,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = None
|
tokenizer = None
|
||||||
# This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4)
|
# This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4)
|
||||||
# Depending on the hardware we get different logits / generations
|
# Depending on the hardware we get different logits / generations
|
||||||
device_properties = None
|
device_properties: DeviceProperties = (None, None, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -595,7 +596,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_sentence, expected_sentence)
|
self.assertEqual(output_sentence, expected_sentence)
|
||||||
|
|
||||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
if self.device_properties == ("cuda", 8):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_ids=input_ids).logits
|
logits = self.model(input_ids=input_ids).logits
|
||||||
|
|
||||||
@@ -638,7 +639,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_sentences[1], expected_sentences[1])
|
self.assertEqual(output_sentences[1], expected_sentences[1])
|
||||||
|
|
||||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
if self.device_properties == ("cuda", 8):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||||
|
|
||||||
|
|||||||
@@ -541,16 +541,24 @@ class JanusIntegrationTest(unittest.TestCase):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
expected_tokens = Expectations(
|
expected_tokens = Expectations(
|
||||||
{
|
{
|
||||||
("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
|
("rocm", None): [
|
||||||
12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
|
10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770, 12353, 5481,
|
||||||
13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
|
11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376, 13219, 6007, 14292, 12696,
|
||||||
6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
|
10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335, 6135, 2316, 15423, 311, 5460, 12218,
|
||||||
],
|
14172, 8583, 14577, 3648
|
||||||
("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
|
],
|
||||||
1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
|
("rocm", (9, 5)): [
|
||||||
10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
|
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
|
||||||
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
|
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
|
||||||
],
|
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
|
||||||
|
897, 4044, 1762, 4676
|
||||||
|
],
|
||||||
|
("cuda", None): [
|
||||||
|
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
|
||||||
|
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
|
||||||
|
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
|
||||||
|
897, 4044, 1762, 4676
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
|
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
# diff on `EXPECTED_TEXT`:
|
# diff on `EXPECTED_TEXT`:
|
||||||
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
|
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
|
||||||
EXPECTED_TEXT = (
|
expected_base_text = (
|
||||||
"Tell me about the french revolution. The french revolution was a period of radical political and social "
|
"Tell me about the french revolution. The french revolution was a period of radical political and social "
|
||||||
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
|
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
|
||||||
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
|
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
|
||||||
@@ -122,6 +122,13 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
|
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
|
||||||
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
|
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
|
||||||
)
|
)
|
||||||
|
expected_texts = Expectations(
|
||||||
|
{
|
||||||
|
("rocm", (9, 5)): expected_base_text.replace("political and social", "social and political"),
|
||||||
|
("cuda", None): expected_base_text,
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_TEXT = expected_texts.get_expectation()
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
|||||||
@@ -341,7 +341,11 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device)
|
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
|
expected_decoded_texts = Expectations({
|
||||||
|
("cuda", None): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,",
|
||||||
|
("rocm", (9, 5)): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. First, the",
|
||||||
|
}) # fmt: skip
|
||||||
|
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.processor.decode(output[0], skip_special_tokens=True),
|
self.processor.decode(output[0], skip_special_tokens=True),
|
||||||
@@ -397,12 +401,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
|
expected_decoded_texts = Expectations(
|
||||||
|
{
|
||||||
self.assertEqual(
|
("cuda", None): [
|
||||||
processor.batch_decode(output, skip_special_tokens=True),
|
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
|
||||||
EXPECTED_DECODED_TEXT,
|
"with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, "
|
||||||
|
"you",
|
||||||
|
"USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat "
|
||||||
|
"is located on",
|
||||||
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
|
||||||
|
"with me? ASSISTANT: When visiting this serene location, which features a wooden pier overlooking a "
|
||||||
|
"lake, you should",
|
||||||
|
"USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat "
|
||||||
|
"is located on",
|
||||||
|
],
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
|
||||||
|
|
||||||
|
decoded_output = processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -433,6 +453,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along',
|
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along',
|
||||||
'USER: \nWhat is this?\nASSISTANT: Cats',
|
'USER: \nWhat is this?\nASSISTANT: Cats',
|
||||||
],
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this dock on a lake, there are several things to be cautious about and items to",
|
||||||
|
"USER: \nWhat is this?\nASSISTANT: This is a picture of two cats lying on a couch.",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
|
EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
|
||||||
@@ -467,12 +491,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
expected_decoded_texts = Expectations(
|
||||||
|
{
|
||||||
self.assertEqual(
|
("cuda", None): [
|
||||||
processor.batch_decode(output, skip_special_tokens=True),
|
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
|
||||||
EXPECTED_DECODED_TEXT,
|
"with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a "
|
||||||
|
"body of water",
|
||||||
|
"USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat "
|
||||||
|
"sleeping on a bed.",
|
||||||
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
|
||||||
|
"with me?\nASSISTANT: When visiting this place, which is a pier or dock overlooking a lake, you should "
|
||||||
|
"be",
|
||||||
|
"USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat "
|
||||||
|
"sleeping on a bed.",
|
||||||
|
],
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
|
||||||
|
|
||||||
|
decoded_output = processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from packaging import version
|
|||||||
|
|
||||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
|
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
DeviceProperties,
|
||||||
Expectations,
|
Expectations,
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
cleanup,
|
cleanup,
|
||||||
@@ -114,7 +115,7 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
|
|||||||
class MistralIntegrationTest(unittest.TestCase):
|
class MistralIntegrationTest(unittest.TestCase):
|
||||||
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
||||||
# Depending on the hardware we get different logits / generations
|
# Depending on the hardware we get different logits / generations
|
||||||
device_properties = None
|
device_properties: DeviceProperties = (None, None, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -279,7 +280,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
if self.device_properties == ("cuda", 7):
|
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||||
self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
|
self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
|
||||||
|
|
||||||
NUM_TOKENS_TO_GENERATE = 40
|
NUM_TOKENS_TO_GENERATE = 40
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import pytest
|
|||||||
from transformers import MixtralConfig, is_torch_available
|
from transformers import MixtralConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
Expectations,
|
Expectations,
|
||||||
get_device_properties,
|
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
@@ -142,14 +141,6 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MixtralIntegrationTest(unittest.TestCase):
|
class MixtralIntegrationTest(unittest.TestCase):
|
||||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
|
||||||
# Depending on the hardware we get different logits / generations
|
|
||||||
device_properties = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.device_properties = get_device_properties()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
def test_small_model_logits(self):
|
def test_small_model_logits(self):
|
||||||
|
|||||||
@@ -445,7 +445,11 @@ class MptIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_text = "Hello"
|
input_text = "Hello"
|
||||||
expected_output = "Hello, I'm a new user of the forum. I have a question about the \"Solaris"
|
expected_outputs = Expectations({
|
||||||
|
("cuda", None): "Hello, I'm a new user of the forum. I have a question about the \"Solaris",
|
||||||
|
("rocm", (9, 5)): "Hello, I'm a newbie to the forum. I have a question about the \"B\" in",
|
||||||
|
}) # fmt: off
|
||||||
|
expected_output = expected_outputs.get_expectation()
|
||||||
|
|
||||||
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
|
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
|
||||||
outputs = model.generate(**inputs, max_new_tokens=20)
|
outputs = model.generate(**inputs, max_new_tokens=20)
|
||||||
@@ -463,19 +467,12 @@ class MptIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_text = "Hello"
|
input_text = "Hello"
|
||||||
expected_outputs = Expectations(
|
expected_outputs = Expectations({
|
||||||
{
|
("rocm", (9, 5)): "Hello and welcome to the first day of the new release at The Stamp Man!\nToday we are",
|
||||||
(
|
("xpu", 3): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n",
|
||||||
"xpu",
|
("cuda", 7): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n",
|
||||||
3,
|
("cuda", 8): "Hello and welcome to the first day of the new release countdown for the month of May!\nToday",
|
||||||
): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n",
|
}) # fmt: off
|
||||||
("cuda", 7): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n",
|
|
||||||
(
|
|
||||||
"cuda",
|
|
||||||
8,
|
|
||||||
): "Hello and welcome to the first day of the new release countdown for the month of May!\nToday",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
expected_output = expected_outputs.get_expectation()
|
expected_output = expected_outputs.get_expectation()
|
||||||
|
|
||||||
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
|
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
|
||||||
@@ -510,6 +507,10 @@ class MptIntegrationTests(unittest.TestCase):
|
|||||||
"Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for the",
|
"Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for the",
|
||||||
"Today I am going at the gym and then I am going to go to the grocery store. I am going to buy some food and some",
|
"Today I am going at the gym and then I am going to go to the grocery store. I am going to buy some food and some",
|
||||||
],
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
"Hello my name is Jasmine and I am a very sweet and loving dog. I am a very playful dog and I",
|
||||||
|
"Today I am going at the gym and then I am going to go to the mall. I am going to buy a new pair of jeans",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
expected_output = expected_outputs.get_expectation()
|
expected_output = expected_outputs.get_expectation()
|
||||||
@@ -535,9 +536,10 @@ class MptIntegrationTests(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
("xpu", 3): torch.Tensor([-0.2090, -0.2061, -0.1465]),
|
("xpu", 3): torch.Tensor([-0.2090, -0.2061, -0.1465]),
|
||||||
("cuda", 7): torch.Tensor([-0.2520, -0.2178, -0.1953]),
|
("cuda", 7): torch.Tensor([-0.2520, -0.2178, -0.1953]),
|
||||||
|
# TODO: This is quite a bit off, check BnB
|
||||||
|
("rocm", (9, 5)): torch.Tensor([-0.3008, -0.1309, -0.1562]),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
expected_slice = expected_slices.get_expectation().to(torch_device, torch.bfloat16)
|
expected_slice = expected_slices.get_expectation().to(torch_device, torch.bfloat16)
|
||||||
predicted_slice = outputs.hidden_states[-1][0, 0, :3]
|
predicted_slice = outputs.hidden_states[-1][0, 0, :3]
|
||||||
|
|
||||||
torch.testing.assert_close(expected_slice, predicted_slice, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(expected_slice, predicted_slice, rtol=1e-3, atol=1e-3)
|
||||||
|
|||||||
@@ -1041,7 +1041,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
(device_type, major) = get_device_properties()
|
device_type, major, _ = get_device_properties()
|
||||||
if device_type == "cuda" and major < 8:
|
if device_type == "cuda" and major < 8:
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
elif device_type == "rocm" and major < 9:
|
elif device_type == "rocm" and major < 9:
|
||||||
|
|||||||
@@ -1041,7 +1041,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
(device_type, major) = get_device_properties()
|
device_type, major, _ = get_device_properties()
|
||||||
if device_type == "cuda" and major < 8:
|
if device_type == "cuda" and major < 8:
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
elif device_type == "rocm" and major < 9:
|
elif device_type == "rocm" and major < 9:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
cleanup,
|
cleanup,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -590,7 +591,13 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
expected_decoded_texts = Expectations(
|
||||||
|
{
|
||||||
|
("rocm", (9, 5)): "detect shoe\n<loc0051><loc0309><loc0708><loc0644> shoe",
|
||||||
|
("cuda", None): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe",
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
|
||||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
def test_paligemma_index_error_bug(self):
|
def test_paligemma_index_error_bug(self):
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import unittest
|
|||||||
from transformers import Phi3Config, StaticCache, is_torch_available
|
from transformers import Phi3Config, StaticCache, is_torch_available
|
||||||
from transformers.models.auto.configuration_auto import AutoConfig
|
from transformers.models.auto.configuration_auto import AutoConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -352,9 +353,14 @@ class Phi3IntegrationTest(unittest.TestCase):
|
|||||||
model_id = "microsoft/Phi-4-mini-instruct"
|
model_id = "microsoft/Phi-4-mini-instruct"
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
|
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
|
||||||
EXPECTED_TEXT_COMPLETION = [
|
|
||||||
"You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus, who is currently on metformin and a SGLT2 inhibitor, presents with a 2-year history"
|
expected_text_completions = Expectations(
|
||||||
]
|
{
|
||||||
|
("rocm", (9, 5)): ["You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus presents with a 2-year history of progressive, non-healing, and painful, 2.5 cm"],
|
||||||
|
("cuda", None): ["You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus, who is currently on metformin and a SGLT2 inhibitor, presents with a 2-year history"],
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
|
||||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from packaging import version
|
|||||||
from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed
|
from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -250,9 +251,17 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
|||||||
qwen_model = "Qwen/Qwen2-0.5B"
|
qwen_model = "Qwen/Qwen2-0.5B"
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
|
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
|
||||||
EXPECTED_TEXT_COMPLETION = [
|
|
||||||
"My favourite condiment is 100% natural, organic, gluten free, vegan, and free from preservatives. I"
|
expected_text_completions = Expectations({
|
||||||
]
|
("cuda", None): [
|
||||||
|
"My favourite condiment is 100% natural, organic, gluten free, vegan, and free from preservatives. I"
|
||||||
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
"My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
|
||||||
|
]
|
||||||
|
}) # fmt: off
|
||||||
|
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
|
||||||
|
|
||||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from packaging import version
|
|||||||
from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_seed
|
from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_seed
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -246,10 +247,18 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
|
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
|
||||||
if version.parse(torch.__version__) == version.parse("2.7.0"):
|
if version.parse(torch.__version__) == version.parse("2.7.0"):
|
||||||
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
|
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
|
||||||
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
|
cuda_expectation = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
|
||||||
else:
|
else:
|
||||||
strict = True
|
strict = True
|
||||||
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
|
cuda_expectation = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
|
||||||
|
|
||||||
|
expected_text_completions = Expectations(
|
||||||
|
{
|
||||||
|
("rocm", (9, 5)): ["My favourite condiment is 100% plain, unflavoured, and unadulterated."],
|
||||||
|
("cuda", None): cuda_expectation,
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
|
||||||
|
|
||||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from pytest import mark
|
|||||||
|
|
||||||
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
|
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -760,17 +761,19 @@ class Siglip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# verify the logits values
|
# verify the logits values
|
||||||
# fmt: off
|
# fmt: off
|
||||||
expected_logits_per_text = torch.tensor(
|
expected_logits_per_texts = Expectations({
|
||||||
[
|
("cuda", None): [
|
||||||
[ 1.0195, -0.0280, -1.4468],
|
[ 1.0195, -0.0280, -1.4468], [ -4.5395, -6.2269, -1.5667], [ 4.1757, 5.0358, 3.5159],
|
||||||
[ -4.5395, -6.2269, -1.5667],
|
[ 9.4264, 10.1879, 6.3353], [ 2.4409, 3.1058, 4.5491], [-12.3230, -13.7355, -13.4632],
|
||||||
[ 4.1757, 5.0358, 3.5159],
|
|
||||||
[ 9.4264, 10.1879, 6.3353],
|
|
||||||
[ 2.4409, 3.1058, 4.5491],
|
|
||||||
[-12.3230, -13.7355, -13.4632],
|
|
||||||
[ 1.1520, 1.1687, -1.9647],
|
[ 1.1520, 1.1687, -1.9647],
|
||||||
]
|
],
|
||||||
).to(torch_device)
|
("rocm", (9, 5)): [
|
||||||
|
[ 1.0236, -0.0376, -1.4464], [ -4.5358, -6.2235, -1.5628], [ 4.1708, 5.0334, 3.5187],
|
||||||
|
[ 9.4241, 10.1828, 6.3366], [ 2.4371, 3.1062, 4.5530], [-12.3173, -13.7240, -13.4580],
|
||||||
|
[ 1.1502, 1.1716, -1.9623]
|
||||||
|
],
|
||||||
|
})
|
||||||
|
EXPECTED_LOGITS_PER_TEXT = torch.tensor(expected_logits_per_texts.get_expectation()).to(torch_device)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
torch.testing.assert_close(outputs.logits_per_text, expected_logits_per_text, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(outputs.logits_per_text, EXPECTED_LOGITS_PER_TEXT, rtol=1e-3, atol=1e-3)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers import PretrainedConfig, VitsConfig
|
from transformers import PretrainedConfig, VitsConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -454,13 +455,21 @@ class VitsModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(outputs.waveform.shape, (1, 87040))
|
self.assertEqual(outputs.waveform.shape, (1, 87040))
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_LOGITS = torch.tensor(
|
expected_logits = Expectations({
|
||||||
[
|
("cuda", None): [
|
||||||
0.0101, 0.0318, 0.0489, 0.0627, 0.0728, 0.0865, 0.1053, 0.1279,
|
0.0101, 0.0318, 0.0489, 0.0627, 0.0728, 0.0865, 0.1053, 0.1279,
|
||||||
0.1514, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1332, 0.1188,
|
0.1514, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1332, 0.1188,
|
||||||
0.1066, 0.0978, 0.0936, 0.0867, 0.0724, 0.0493, 0.0197, -0.0141,
|
0.1066, 0.0978, 0.0936, 0.0867, 0.0724, 0.0493, 0.0197, -0.0141,
|
||||||
-0.0501, -0.0817, -0.1065, -0.1223, -0.1311, -0.1339
|
-0.0501, -0.0817, -0.1065, -0.1223, -0.1311, -0.1339
|
||||||
|
],
|
||||||
|
("rocm", (9, 5)): [
|
||||||
|
0.0097, 0.0315, 0.0486, 0.0626, 0.0728, 0.0865, 0.1053, 0.1279,
|
||||||
|
0.1515, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1333, 0.1189,
|
||||||
|
0.1066, 0.0978, 0.0937, 0.0868, 0.0726, 0.0496, 0.0200, -0.0138,
|
||||||
|
-0.0500, -0.0817, -0.1067, -0.1225, -0.1313, -0.1340
|
||||||
]
|
]
|
||||||
).to(torch.float16)
|
})
|
||||||
|
EXPECTED_LOGITS = torch.tensor(expected_logits.get_expectation(), dtype=torch.float16)
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
torch.testing.assert_close(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||||
|
|||||||
@@ -17,7 +17,9 @@ import unittest
|
|||||||
|
|
||||||
from transformers import XGLMConfig, is_torch_available
|
from transformers import XGLMConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
cleanup,
|
cleanup,
|
||||||
|
is_torch_greater_or_equal,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_fp16,
|
require_torch_fp16,
|
||||||
@@ -422,13 +424,21 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
||||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
EXPECTED_OUTPUT_STRS = [
|
if is_torch_greater_or_equal("2.7.0"):
|
||||||
# torch 2.6
|
cuda_expectation = (
|
||||||
"Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a",
|
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today."
|
||||||
# torch 2.7
|
)
|
||||||
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
|
else:
|
||||||
]
|
cuda_expectation = "Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a"
|
||||||
self.assertIn(output_str, EXPECTED_OUTPUT_STRS)
|
|
||||||
|
expected_output_strings = Expectations(
|
||||||
|
{
|
||||||
|
("rocm", (9, 5)): "Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
|
||||||
|
("cuda", None): cuda_expectation,
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_OUTPUT_STR = expected_output_strings.get_expectation()
|
||||||
|
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torch_fp16
|
@require_torch_fp16
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.pipelines import MaskGenerationPipeline
|
from transformers.pipelines import MaskGenerationPipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
@@ -120,6 +121,11 @@ class MaskGenerationPipelineTests(unittest.TestCase):
|
|||||||
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
|
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
last_output = Expectations({
|
||||||
|
("cuda", None): {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871},
|
||||||
|
("rocm", (9, 5)): {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8872}
|
||||||
|
}).get_expectation()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(new_outupt, decimals=4),
|
nested_simplify(new_outupt, decimals=4),
|
||||||
[
|
[
|
||||||
@@ -152,7 +158,7 @@ class MaskGenerationPipelineTests(unittest.TestCase):
|
|||||||
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
|
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
|
||||||
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
|
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
|
||||||
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
|
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
|
||||||
{'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
|
last_output
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class AwqConfigTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Only cuda and xpu devices can run this function
|
# Only cuda and xpu devices can run this function
|
||||||
support_llm_awq = False
|
support_llm_awq = False
|
||||||
device_type, major = get_device_properties()
|
device_type, major, _ = get_device_properties()
|
||||||
if device_type == "cuda" and major >= 8:
|
if device_type == "cuda" and major >= 8:
|
||||||
support_llm_awq = True
|
support_llm_awq = True
|
||||||
elif device_type == "xpu":
|
elif device_type == "xpu":
|
||||||
|
|||||||
@@ -552,7 +552,8 @@ class TorchAoSerializationFP8AcceleratorTest(TorchAoSerializationTest):
|
|||||||
# called only once for all test in this class
|
# called only once for all test in this class
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
if get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9:
|
device_type, major, minor = get_device_properties()
|
||||||
|
if device_type == "cuda" and major < 9:
|
||||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||||
|
|
||||||
from torchao.quantization import Float8WeightOnlyConfig
|
from torchao.quantization import Float8WeightOnlyConfig
|
||||||
@@ -573,7 +574,8 @@ class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
|
|||||||
# called only once for all test in this class
|
# called only once for all test in this class
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
if get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9:
|
device_type, major, minor = get_device_properties()
|
||||||
|
if device_type == "cuda" and major < 9:
|
||||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||||
|
|
||||||
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
|
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
|
||||||
|
|||||||
@@ -3775,7 +3775,7 @@ class ModelTesterMixin:
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
(device_type, major) = get_device_properties()
|
device_type, major, minor = get_device_properties()
|
||||||
if device_type == "cuda" and major < 8:
|
if device_type == "cuda" and major < 8:
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
elif device_type == "rocm" and major < 9:
|
elif device_type == "rocm" and major < 9:
|
||||||
@@ -3823,7 +3823,7 @@ class ModelTesterMixin:
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
(device_type, major) = get_device_properties()
|
device_type, major, minor = get_device_properties()
|
||||||
if device_type == "cuda" and major < 8:
|
if device_type == "cuda" and major < 8:
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
elif device_type == "rocm" and major < 9:
|
elif device_type == "rocm" and major < 9:
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from transformers.testing_utils import Expectations
|
|||||||
|
|
||||||
class ExpectationsTest(unittest.TestCase):
|
class ExpectationsTest(unittest.TestCase):
|
||||||
def test_expectations(self):
|
def test_expectations(self):
|
||||||
|
# We use the expectations below to make sure the right expectations are found for the right devices.
|
||||||
|
# Each value is just a unique ID.
|
||||||
expectations = Expectations(
|
expectations = Expectations(
|
||||||
{
|
{
|
||||||
(None, None): 1,
|
(None, None): 1,
|
||||||
@@ -17,18 +19,20 @@ class ExpectationsTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def check(value, key):
|
def check(expected_id, device_prop):
|
||||||
assert expectations.find_expectation(key) == value
|
found_id = expectations.find_expectation(device_prop)
|
||||||
|
assert found_id == expected_id, f"Expected {expected_id} for {device_prop}, found {found_id}"
|
||||||
|
|
||||||
# npu has no matches so should find default expectation
|
# npu has no matches so should find default expectation
|
||||||
check(1, ("npu", None))
|
check(1, ("npu", None, None))
|
||||||
check(7, ("xpu", 3))
|
check(7, ("xpu", 3, None))
|
||||||
check(2, ("cuda", 8))
|
check(2, ("cuda", 8, None))
|
||||||
check(3, ("cuda", 7))
|
check(3, ("cuda", 7, None))
|
||||||
check(4, ("rocm", 9))
|
check(4, ("rocm", 9, None))
|
||||||
check(4, ("rocm", None))
|
check(4, ("rocm", None, None))
|
||||||
check(2, ("cuda", 2))
|
check(2, ("cuda", 2, None))
|
||||||
|
|
||||||
|
# We also test that if there is no default excpectation and no match is found, a ValueError is raised.
|
||||||
expectations = Expectations({("cuda", 8): 1})
|
expectations = Expectations({("cuda", 8): 1})
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
expectations.find_expectation(("xpu", None))
|
expectations.find_expectation(("xpu", None))
|
||||||
|
|||||||
Reference in New Issue
Block a user