Update bros checkpoint (#26277)

* fix bros integration test

* update bros checkpoint
This commit is contained in:
Jinho Park
2023-09-20 17:22:07 +09:00
committed by GitHub
parent 86ffd5ffa2
commit 37c205eb5d
4 changed files with 22 additions and 25 deletions

View File

@@ -17,9 +17,8 @@
import copy
import unittest
from transformers import BrosProcessor
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import is_torch_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -412,13 +411,10 @@ def prepare_bros_batch_inputs():
@require_torch
class BrosModelIntegrationTest(unittest.TestCase):
@cached_property
def default_processor(self):
return BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased") if is_vision_available() else None
@slow
def test_inference_no_head(self):
model = BrosModel.from_pretrained("naver-clova-ocr/bros-base-uncased").to(torch_device)
model = BrosModel.from_pretrained("jinho8345/bros-base-uncased").to(torch_device)
input_ids, bbox, attention_mask = prepare_bros_batch_inputs()
with torch.no_grad():
@@ -434,7 +430,8 @@ class BrosModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.4027, 0.0756, -0.0647], [-0.0192, -0.0065, 0.1042], [-0.0671, 0.0214, 0.0960]]
[[-0.3074, 0.1363, 0.3143], [0.0925, -0.1155, 0.1050], [0.0221, 0.0003, 0.1285]]
).to(torch_device)
torch.set_printoptions(sci_mode=False)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))