[RoBERTa-based] Add support for sdpa (#30510)
* Adding SDPA support for RoBERTa-based models * add not is_cross_attention * fix copies * fix test * add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin * address some review comments * use copied from * style * consistency * fix lists --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -70,6 +70,7 @@ def check_sdpa_support_list():
|
||||
"For now, Transformers supports SDPA inference and training for the following architectures:"
|
||||
)[1]
|
||||
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
|
||||
doctext = doctext.lower()
|
||||
|
||||
patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
|
||||
patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
|
||||
@@ -85,7 +86,7 @@ def check_sdpa_support_list():
|
||||
archs_supporting_sdpa.append(model_name)
|
||||
|
||||
for arch in archs_supporting_sdpa:
|
||||
if arch not in doctext and arch not in doctext.replace("-", "_"):
|
||||
if not any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")]):
|
||||
raise ValueError(
|
||||
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user