@@ -22,6 +22,7 @@ import pytest
|
|||||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -487,6 +488,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
class BambaModelIntegrationTest(unittest.TestCase):
|
class BambaModelIntegrationTest(unittest.TestCase):
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|||||||
Reference in New Issue
Block a user