[RoBERTa-based] Add support for sdpa (#30510)

* Adding SDPA support for RoBERTa-based models

* add not is_cross_attention

* fix copies

* fix test

* add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin

* address some review comments

* use copied from

* style

* consistency

* fix lists

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
JB (Don)
2024-08-28 16:26:00 +08:00
committed by GitHub
parent e0b87b0f40
commit f1a385b1de
11 changed files with 828 additions and 100 deletions

@@ -70,6 +70,7 @@ def check_sdpa_support_list():
"For now, Transformers supports SDPA inference and training for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
+doctext = doctext.lower()
patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
@@ -85,7 +86,7 @@ def check_sdpa_support_list():
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
-if arch not in doctext and arch not in doctext.replace("-", "_"):
+if not any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")]):
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
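
The effect of the updated check can be illustrated in isolation. The following is a minimal sketch, not the repository's code: the helper name `is_listed` and the sample `doc` string are hypothetical, and only the lowercasing step and the `any(...)` expression are taken from the diff above.

```python
def is_listed(arch: str, doctext: str) -> bool:
    """Return True if any spelling variant of `arch` occurs in the doc text.

    Hypothetical helper: it mirrors the lowercasing and the any(...) test
    from the diff, but it is not a function in the repository.
    """
    doctext = doctext.lower()
    # Try the name as-is, with hyphens, and with spaces instead of underscores.
    variants = [arch, arch.replace("_", "-"), arch.replace("_", " ")]
    return any(term in doctext for term in variants)


# Made-up documentation excerpt, used only for illustration.
doc = "* XLM-RoBERTa\n* Audio Spectrogram Transformer\n"

print(is_listed("xlm_roberta", doc))
print(is_listed("audio_spectrogram_transformer", doc))
print(is_listed("camembert", doc))
```

In this sketch the first two names match through their hyphenated and space-separated variants, while the third has no match, which is the case where the real check raises `ValueError`.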