Blenderbot (#7418)
Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor
|
||||
from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -79,7 +79,17 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
|
||||
result_slice = logits[0, 0, :3]
|
||||
_assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)
|
||||
assert_tensors_close(expected_slice, result_slice, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
["UN Chief Says There Is No Military Solution in Syria"]
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
# self.assertEqual(self.tgt_text[1], decoded[1])
|
||||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
|
||||
Reference in New Issue
Block a user