CI: AMD MI300 tests fix (#30797)
* add fix * update import * updated dicts and comments * remove prints * Update testing_utils.py
This commit is contained in:
@@ -601,6 +601,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_model_2b_bf16(self):
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 currently holds the outputs recorded on MI300. It may need adjusting for H100s in the future,
|
||||
# since differences in hardware processing can cause deviations in the generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
@@ -610,6 +615,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
9: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
@@ -627,6 +636,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_model_2b_eager(self):
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 currently holds the outputs recorded on MI300. It may need adjusting for H100s in the future,
|
||||
# since differences in hardware processing can cause deviations in the generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
@@ -636,6 +650,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
9: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -655,6 +673,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_model_2b_sdpa(self):
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 currently holds the outputs recorded on MI300. It may need adjusting for H100s in the future,
|
||||
# since differences in hardware processing can cause deviations in the generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
@@ -664,6 +687,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
9: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -763,6 +790,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_model_7b_bf16(self):
|
||||
model_id = "google/gemma-7b"
|
||||
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 currently holds the outputs recorded on MI300. It may need adjusting for H100s in the future,
|
||||
# since differences in hardware processing can cause deviations in the generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||
@@ -772,6 +804,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will read a .txt file",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
],
|
||||
9: [
|
||||
"Hello I am doing a project for my school and I am trying to get a servo to move a certain amount of degrees",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make DIY light up sign",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
@@ -845,6 +881,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
|
||||
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
|
||||
#
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 currently holds the outputs recorded on MI300. It may need adjusting for H100s in the future,
|
||||
# since differences in hardware processing can cause deviations in the generated text.
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
|
||||
@@ -854,6 +895,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
|
||||
"Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
|
||||
],
|
||||
9: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
|
||||
"Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
|
||||
],
|
||||
}
|
||||
|
||||
prompts = ["Hello I am doing", "Hi today"]
|
||||
|
||||
Reference in New Issue
Block a user