Update bros checkpoint (#26277)
* fix bros integration test * update bros checkpoint
This commit is contained in:
@@ -17,9 +17,8 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from transformers import BrosProcessor
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
@@ -412,13 +411,10 @@ def prepare_bros_batch_inputs():
|
||||
|
||||
@require_torch
|
||||
class BrosModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = BrosModel.from_pretrained("naver-clova-ocr/bros-base-uncased").to(torch_device)
|
||||
model = BrosModel.from_pretrained("jinho8345/bros-base-uncased").to(torch_device)
|
||||
|
||||
input_ids, bbox, attention_mask = prepare_bros_batch_inputs()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -434,7 +430,8 @@ class BrosModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.4027, 0.0756, -0.0647], [-0.0192, -0.0065, 0.1042], [-0.0671, 0.0214, 0.0960]]
|
||||
[[-0.3074, 0.1363, 0.3143], [0.0925, -0.1155, 0.1050], [0.0221, 0.0003, 0.1285]]
|
||||
).to(torch_device)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user