Garrett Goon
503541d7ef
add FlashAttentionKwargs and seq_idx to flat collator (#36456)
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simplify wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
2025-04-16 15:45:03 +02:00
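The sketch below shows what a flattening (padding-free) collator with these options computes. It assumes each example is a non-empty list of token ids and that everything is packed into a single row with a batch dimension of 1. `flatten_batch` is a hypothetical helper for illustration, not the transformers `DataCollatorWithFlattening` code. The key names `cu_seq_lens_q`/`cu_seq_lens_k` and `max_length_q`/`max_length_k` follow the commit's "cu_seq_len" and "max_length_{k,q}" wording but are assumptions here. `separator_id` mirrors the parameter named in the commit, with -100 assumed as its default.

```python
from typing import Dict, List, Union

import numpy as np


def flatten_batch(
    features: List[List[int]],
    separator_id: int = -100,
) -> Dict[str, Union[np.ndarray, int]]:
    """Pack variable-length sequences into one row (hypothetical helper).

    Assumes every sequence in `features` is non-empty.
    """
    input_ids: List[int] = []
    labels: List[int] = []
    position_ids: List[int] = []
    seq_idx: List[int] = []
    cu_seq_lens = [0]  # cumulative sequence boundaries, starting at 0
    max_length = 0
    for i, ids in enumerate(features):
        n = len(ids)
        input_ids.extend(ids)
        # Mask the first label of each sequence so the loss never crosses a boundary.
        labels.extend([separator_id] + ids[1:])
        position_ids.extend(range(n))  # positions restart at 0 for each sequence
        seq_idx.extend([i] * n)  # index of the original sequence for each token
        cu_seq_lens.append(cu_seq_lens[-1] + n)
        max_length = max(max_length, n)
    cu = np.array(cu_seq_lens, dtype=np.int32)
    return {
        # Leading batch dim of 1: all sequences share one packed row.
        "input_ids": np.array([input_ids], dtype=np.int64),
        "labels": np.array([labels], dtype=np.int64),
        "position_ids": np.array([position_ids], dtype=np.int64),
        "seq_idx": np.array([seq_idx], dtype=np.int32),
        # Self-attention: query and key boundaries are identical.
        "cu_seq_lens_q": cu,
        "cu_seq_lens_k": cu,
        # Plain Python ints, computed once for both q and k.
        "max_length_q": max_length,
        "max_length_k": max_length,
    }
```

For example, `flatten_batch([[5, 6, 7], [8, 9]])` packs five tokens into one row with `cu_seq_lens_q == [0, 3, 5]`, `max_length_q == 3`, `position_ids == [[0, 1, 2, 0, 1]]` and `seq_idx == [[0, 0, 0, 1, 1]]`. The boundaries are computed once and shared between q and k. The max lengths are returned as Python ints rather than arrays, in the spirit of "compute cu_seq_len/max_length once" and "use py ints for max_length_{k,q}".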