MBartTokenizer: add language codes (#3776)
This commit is contained in:
@@ -19,6 +19,7 @@ import unittest
|
||||
import timeout_decorator # noqa
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
@@ -37,6 +38,7 @@ if is_torch_available():
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
BatchEncoding,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
tiny(**inputs_dict)
|
||||
|
||||
|
||||
EN_CODE = 250004
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartTranslationTests(unittest.TestCase):
|
||||
_model = None
|
||||
class MBartIntegrationTests(unittest.TestCase):
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
" I ate lunch twice yesterday",
|
||||
]
|
||||
tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
|
||||
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
|
||||
|
||||
@classmethod
def setUpClass(cls):
    """Prepare the fixtures shared by every test in the class.

    Stores the pad token id used by the tests and loads the tokenizer for
    the ``facebook/mbart-large-en-ro`` checkpoint as ``cls.tokenizer``.
    """
    cls.pad_token_id = 1
    cls.tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    # The class itself is handed back, exactly as the original hook does.
    return cls
|
||||
|
||||
@cached_property
def model(self):
    """Load the en-ro MBart checkpoint on first access.

    The weights are moved to ``torch_device``; when that device string
    contains ``"cuda"`` the model is additionally converted with ``.half()``.
    """
    pretrained = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
    on_device = pretrained.to(torch_device)
    return on_device.half() if "cuda" in torch_device else on_device
|
||||
|
||||
@slow
|
||||
def test_enro_forward(self):
|
||||
model = self.model
|
||||
net_input = {
|
||||
"input_ids": _long_tensor(
|
||||
[
|
||||
@@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase):
|
||||
),
|
||||
"generation_mode": False,
|
||||
}
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
|
||||
cls.net_input = net_input
|
||||
|
||||
return cls
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Only load the model if needed."""
|
||||
if self._model is None:
|
||||
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
self._model = model.to(torch_device)
|
||||
return self._model
|
||||
|
||||
@slow
|
||||
def test_enro_forward(self):
|
||||
model = self.model
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
|
||||
with torch.no_grad():
|
||||
logits, *other_stuff = model(**self.net_input)
|
||||
logits, *other_stuff = model(**net_input)
|
||||
|
||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
|
||||
result_slice = logits[0][0][:3]
|
||||
@@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
model = self.model
|
||||
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.LongTensor(
|
||||
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
|
||||
)
|
||||
}
|
||||
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
|
||||
inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
|
||||
translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(expected_translation_romanian, decoded[0])
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
|
||||
def test_mbart_enro_config(self):
|
||||
mbart_models = ["facebook/mbart-large-en-ro"]
|
||||
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
|
||||
e.args += (name, k)
|
||||
raise
|
||||
|
||||
def test_enro_tokenizer(self):
|
||||
raw = "UN Chief Says There Is No Military Solution in Syria"
|
||||
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
|
||||
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
|
||||
# TODO(SS): should be [8274, ..., 2, 250020]
|
||||
self.assertListEqual(expected_result, ids)
|
||||
|
||||
def test_mbart_fast_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=99,
|
||||
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
class MBartTokenizerTests(MBartIntegrationTests):
    """Tokenizer checks for the en-ro MBart checkpoint.

    Uses ``self.tokenizer``, ``src_text``, ``tgt_text`` and
    ``expected_src_tokens`` supplied by ``MBartIntegrationTests``.
    """

    # NOTE(review): subclassing MBartIntegrationTests means its test methods
    # are also collected under this class -- confirm that is intended.

    def test_enro_tokenizer_prepare_translation_batch(self):
        """A source/target batch has the expected type, shape and token ids."""
        max_length = len(self.expected_src_tokens)
        batch = self.tokenizer.prepare_translation_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=max_length,
        )
        self.assertIsInstance(batch, BatchEncoding)

        # Derive the shape from the fixtures instead of repeating (2, 14), so
        # the check stays correct if the fixture sentences are edited.
        expected_shape = (len(self.src_text), max_length)
        self.assertEqual(expected_shape, batch.input_ids.shape)
        self.assertEqual(expected_shape, batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS

    def test_enro_tokenizer_batch_encode_plus(self):
        """``batch_encode_plus`` yields the expected ids for the first source sentence."""
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)

    def test_enro_tokenizer_truncation(self):
        """A truncated sequence still ends with id 2 followed by ``EN_CODE``."""
        src_text = ["this is gunna be a long sentence " * 20]
        # unittest assertion rather than a bare ``assert``: it is not removed
        # under ``python -O`` and matches the other checks in this class.
        self.assertIsInstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer.prepare_translation_batch(
            src_text, return_tensors=None, max_length=desired_max_length
        ).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
vocab_size = 99
|
||||
|
||||
Reference in New Issue
Block a user