Unbreak optimum-executorch (#38646)
* Unbreak optimum-executorch
* use static cache if has layer_types but no sliding_window
* revert view on kv_arange

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -378,7 +378,6 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
|
||||
@@ -424,7 +423,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -313,7 +313,6 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
|
||||
@@ -363,7 +362,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -306,7 +306,6 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
llama_models = {
|
||||
@@ -352,7 +351,10 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -334,7 +334,6 @@ class OlmoIntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
olmo_model = "allenai/OLMo-1B-hf"
|
||||
@@ -382,7 +381,10 @@ class OlmoIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -347,7 +347,6 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
model_id = "microsoft/Phi-4-mini-instruct"
|
||||
@@ -399,7 +398,10 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -31,7 +31,6 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torch_greater_or_equal
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -246,7 +245,6 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
qwen_model = "Qwen/Qwen2-0.5B"
|
||||
@@ -287,8 +285,13 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
strict = is_torch_greater_or_equal("2.7.0") # Due to https://github.com/pytorch/pytorch/issues/150994
|
||||
exported_program = convert_and_export_with_cache(model, strict=strict)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
strict = version.parse(torch.__version__) != version.parse(
|
||||
"2.7.0"
|
||||
) # Due to https://github.com/pytorch/pytorch/issues/150994
|
||||
exported_program = exportable_module.export(strict=strict)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -31,7 +31,6 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torch_greater_or_equal
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -240,13 +239,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
qwen_model = "Qwen/Qwen3-0.6B-Base"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
|
||||
if is_torch_greater_or_equal("2.7.0"):
|
||||
if version.parse(torch.__version__) == version.parse("2.7.0"):
|
||||
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
|
||||
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
|
||||
else:
|
||||
@@ -285,7 +283,10 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model, strict=strict)
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export(strict=strict)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user