[GPT2] Add SDPA support (#31172)
* `gpt2` sdpa support * fix (at least) one test, style, repo consistency * fix sdpa mask in forward --> fixes generation * test * test2 * test3 * test4 * simplify shapes for attn mask creation and small comments * hub fail test * benchmarks * flash attn 2 mask should not be inverted on enc-dec setup * fix comment * apply some suggestion from code review - only save _attn_implentation once - remove unnecessary comment * change elif logic * [run-slow] gpt2 * modify `test_gpt2_sample_max_time` to follow previous assertion patterns
This commit is contained in:
@@ -832,7 +832,8 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_gpt2(self):
|
||||
|
||||
Reference in New Issue
Block a user