Fix (non-slow) tests on GPU (torch) (#3024)
* Fix tests on GPU (torch) * Fix bart slow tests Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -86,7 +86,7 @@ def _prepare_bart_decoder_inputs(
|
|||||||
causal_lm_mask = None
|
causal_lm_mask = None
|
||||||
new_shape = (bsz, tgt_len, tgt_len)
|
new_shape = (bsz, tgt_len, tgt_len)
|
||||||
# make it broadcastable so can just be added to the attention coefficients
|
# make it broadcastable so can just be added to the attention coefficients
|
||||||
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape)
|
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
|
||||||
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
|
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
|
||||||
return decoder_input_ids, decoder_attn_mask
|
return decoder_input_ids, decoder_attn_mask
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
vocab_size = 99
|
vocab_size = 99
|
||||||
|
|
||||||
def test_lm_forward(self):
|
def test_lm_forward(self):
|
||||||
input_ids = torch.Tensor(
|
input_ids = torch.tensor(
|
||||||
[
|
[
|
||||||
[71, 82, 18, 33, 46, 91, 2],
|
[71, 82, 18, 33, 46, 91, 2],
|
||||||
[68, 34, 26, 58, 30, 82, 2],
|
[68, 34, 26, 58, 30, 82, 2],
|
||||||
@@ -187,8 +187,10 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
[21, 5, 62, 28, 14, 76, 2],
|
[21, 5, 62, 28, 14, 76, 2],
|
||||||
[45, 98, 37, 86, 59, 48, 2],
|
[45, 98, 37, 86, 59, 48, 2],
|
||||||
[70, 70, 50, 9, 28, 0, 2],
|
[70, 70, 50, 9, 28, 0, 2],
|
||||||
]
|
],
|
||||||
).long()
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||||
|
|
||||||
@@ -204,12 +206,14 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
)
|
)
|
||||||
model = BartForSequenceClassification(config)
|
model = BartForSequenceClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForMaskedLM(config)
|
||||||
|
lm_model.to(torch_device)
|
||||||
loss, logits, enc_features = lm_model.forward(
|
loss, logits, enc_features = lm_model.forward(
|
||||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||||
)
|
)
|
||||||
@@ -292,6 +296,10 @@ def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
|||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|
||||||
@@ -299,15 +307,15 @@ TOLERANCE = 1e-4
|
|||||||
class BartModelIntegrationTest(unittest.TestCase):
|
class BartModelIntegrationTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = BartModel.from_pretrained("bart-large")
|
model = BartModel.from_pretrained("bart-large").to(torch_device)
|
||||||
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]).long()
|
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model.forward(**inputs_dict)[0]
|
output = model.forward(**inputs_dict)[0]
|
||||||
expected_shape = torch.Size((1, 11, 1024))
|
expected_shape = torch.Size((1, 11, 1024))
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
expected_slice = torch.Tensor(
|
expected_slice = torch.Tensor(
|
||||||
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]]
|
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
|
|
||||||
@@ -315,20 +323,22 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
def test_mnli_inference(self):
|
def test_mnli_inference(self):
|
||||||
|
|
||||||
example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
|
example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
|
||||||
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b]).long()
|
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b])
|
||||||
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli") # eval called in from_pre
|
model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli").to(
|
||||||
|
torch_device
|
||||||
|
) # eval called in from_pre
|
||||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||||
# Test that model hasn't changed
|
# Test that model hasn't changed
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
batched_logits, features = model.forward(**inputs_dict)
|
batched_logits, features = model.forward(**inputs_dict)
|
||||||
expected_shape = torch.Size((2, 3))
|
expected_shape = torch.Size((2, 3))
|
||||||
self.assertEqual(batched_logits.shape, expected_shape)
|
self.assertEqual(batched_logits.shape, expected_shape)
|
||||||
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]])
|
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)
|
||||||
logits_arr = batched_logits[0].detach()
|
logits_arr = batched_logits[0].detach()
|
||||||
|
|
||||||
# Test that padding does not change results
|
# Test that padding does not change results
|
||||||
input_ids_no_pad = torch.Tensor([example_b[:-1]]).long()
|
input_ids_no_pad = _long_tensor([example_b[:-1]])
|
||||||
|
|
||||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
|
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class ModelTesterMixin:
|
|||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**inputs_dict)
|
||||||
out_2 = outputs[0].numpy()
|
out_2 = outputs[0].cpu().numpy()
|
||||||
out_2[np.isnan(out_2)] = 0
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
@@ -472,6 +472,7 @@ class ModelTesterMixin:
|
|||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config = copy.deepcopy(original_config)
|
config = copy.deepcopy(original_config)
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
model_vocab_size = config.vocab_size
|
model_vocab_size = config.vocab_size
|
||||||
# Retrieve the embeddings and clone theme
|
# Retrieve the embeddings and clone theme
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -125,6 +125,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
decoder_lm_labels,
|
decoder_lm_labels,
|
||||||
):
|
):
|
||||||
model = T5Model(config=config)
|
model = T5Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
decoder_output, encoder_output = model(
|
decoder_output, encoder_output = model(
|
||||||
encoder_input_ids=encoder_input_ids,
|
encoder_input_ids=encoder_input_ids,
|
||||||
@@ -157,6 +158,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
decoder_lm_labels,
|
decoder_lm_labels,
|
||||||
):
|
):
|
||||||
model = T5WithLMHeadModel(config=config)
|
model = T5WithLMHeadModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
outputs = model(
|
outputs = model(
|
||||||
encoder_input_ids=encoder_input_ids,
|
encoder_input_ids=encoder_input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user