* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1
* fix code format
* add test; replace position_ids with query_states because position_ids.shape[0] is always 1
* add assert loss is not nan
229 KiB
Executable File