Fix (non-slow) tests on GPU (torch) (#3024)
* Fix tests on GPU (torch) * Fix bart slow tests Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .utils import CACHE_DIR, require_torch, slow
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -125,6 +125,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
decoder_lm_labels,
|
||||
):
|
||||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
decoder_output, encoder_output = model(
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
@@ -157,6 +158,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
decoder_lm_labels,
|
||||
):
|
||||
model = T5WithLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
outputs = model(
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
|
||||
Reference in New Issue
Block a user