Fix (non-slow) tests on GPU (torch) (#3024)

* Fix tests on GPU (torch)

* Fix bart slow tests

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-02-26 11:59:25 -05:00
committed by GitHub
parent 9df74b8bc4
commit 9cda3620b6
4 changed files with 26 additions and 13 deletions

View File

@@ -68,7 +68,7 @@ class ModelTesterMixin:
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
out_2 = outputs[0].numpy()
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -472,6 +472,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.vocab_size
# Retrieve the embeddings and clone theme