* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ci failure fix

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add check

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code, extend to cuda

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix review comments

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
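
For context on the first item: per the PyTorch documentation, passing `enable_gqa=True` to `torch.nn.functional.scaled_dot_product_attention` lets the query tensor have more heads than key/value, behaving as if key and value were expanded with `repeat_interleave` along the head dimension. The following is only a plain-Python sketch of that head grouping, not the code from this change; the helper names `kv_head_for_query_head` and `repeat_interleave_heads` and the 8-to-2 head counts are illustrative assumptions.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the key/value head index that a query head reads under grouped-query attention."""
    # Hypothetical helper: the query head count must be a multiple of the KV head count.
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def repeat_interleave_heads(heads: list, repeats: int) -> list:
    """Plain-Python stand-in for Tensor.repeat_interleave along the head dimension."""
    return [head for head in heads for _ in range(repeats)]


# Assumed example sizes: 8 query heads sharing 2 key/value heads.
num_q_heads, num_kv_heads = 8, 2
kv_heads = ["kv0", "kv1"]

# Expanding KV heads explicitly (what callers do without enable_gqa) ...
expanded = repeat_interleave_heads(kv_heads, num_q_heads // num_kv_heads)

# ... matches looking up each query head's group directly.
mapping = [kv_head_for_query_head(q, num_q_heads, num_kv_heads) for q in range(num_q_heads)]
```

Here each consecutive group of four query heads shares one key/value head, which is why the explicit expansion can be skipped when the attention kernel handles the grouping itself.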