Fix (non-slow) tests on GPU (torch) (#3024)
* Fix tests on GPU (torch) * Fix bart slow tests Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -68,7 +68,7 @@ class ModelTesterMixin:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
out_2 = outputs[0].numpy()
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -472,6 +472,7 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
|
||||
Reference in New Issue
Block a user