correct opt (#17301)
This commit is contained in:
committed by
GitHub
parent
349f1c85d3
commit
1f13ba818e
@@ -22,7 +22,7 @@ import unittest
|
|||||||
import timeout_decorator # noqa
|
import timeout_decorator # noqa
|
||||||
|
|
||||||
from transformers import OPTConfig, is_torch_available
|
from transformers import OPTConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -266,25 +266,21 @@ def _long_tensor(tok_lst):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_sentencepiece
|
|
||||||
@require_tokenizers
|
|
||||||
class OPTModelIntegrationTests(unittest.TestCase):
|
class OPTModelIntegrationTests(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
|
model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
|
||||||
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||||
attention_mask = input_ids.ne(model.config.pad_token_id)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
|
output = model(input_ids=input_ids).last_hidden_state
|
||||||
expected_shape = torch.Size((1, 11, 512))
|
expected_shape = torch.Size((1, 11, 512))
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]], device=torch_device
|
[[-0.2867, -1.9256, -0.3062], [-1.2711, -0.1337, -0.1897], [0.4109, 0.1187, -1.3142]], device=torch_device
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
class OPTEmbeddingsTest(unittest.TestCase):
|
class OPTEmbeddingsTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user