Fix (non-slow) tests on GPU (torch) (#3024)

* Fix tests on GPU (torch)

* Fix bart slow tests

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-02-26 11:59:25 -05:00
committed by GitHub
parent 9df74b8bc4
commit 9cda3620b6
4 changed files with 26 additions and 13 deletions

View File

@@ -20,7 +20,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
@@ -125,6 +125,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids,
@@ -157,6 +158,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_lm_labels,
):
model = T5WithLMHeadModel(config=config)
model.to(torch_device)
model.eval()
outputs = model(
encoder_input_ids=encoder_input_ids,