MBartTokenizer:add language codes (#3776)

This commit is contained in:
Sam Shleifer
2020-06-11 13:02:33 -04:00
committed by GitHub
parent 20451195f0
commit 08b59d10e5
2 changed files with 161 additions and 38 deletions

View File

@@ -19,6 +19,7 @@ import unittest
import timeout_decorator # noqa
from transformers import is_torch_available
from transformers.file_utils import cached_property
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
@@ -37,6 +38,7 @@ if is_torch_available():
BartConfig,
BartTokenizer,
MBartTokenizer,
BatchEncoding,
)
from transformers.modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
tiny(**inputs_dict)
EN_CODE = 250004
@require_torch
class BartTranslationTests(unittest.TestCase):
_model = None
class MBartIntegrationTests(unittest.TestCase):
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
" I ate lunch twice yesterday",
]
tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
@classmethod
def setUpClass(cls):
checkpoint_name = "facebook/mbart-large-en-ro"
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
def model(self):
"""Only load the model if needed."""
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
if "cuda" in torch_device:
model = model.half()
return model
@slow
def test_enro_forward(self):
model = self.model
net_input = {
"input_ids": _long_tensor(
[
@@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase):
),
"generation_mode": False,
}
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
cls.net_input = net_input
return cls
@property
def model(self):
"""Only load the model if needed."""
if self._model is None:
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
self._model = model.to(torch_device)
return self._model
@slow
def test_enro_forward(self):
model = self.model
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
with torch.no_grad():
logits, *other_stuff = model(**self.net_input)
logits, *other_stuff = model(**net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
result_slice = logits[0][0][:3]
@@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
@slow
def test_enro_generate(self):
model = self.model
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
inputs = {
"input_ids": torch.LongTensor(
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
)
}
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(expected_translation_romanian, decoded[0])
self.assertEqual(self.tgt_text[0], decoded[0])
def test_mbart_enro_config(self):
mbart_models = ["facebook/mbart-large-en-ro"]
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
e.args += (name, k)
raise
def test_enro_tokenizer(self):
raw = "UN Chief Says There Is No Military Solution in Syria"
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
# TODO(SS): should be [8274, ..., 2, 250020]
self.assertListEqual(expected_result, ids)
def test_mbart_fast_forward(self):
config = BartConfig(
vocab_size=99,
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
self.assertEqual(logits.shape, expected_shape)
@require_torch
class MBartTokenizerTests(MBartIntegrationTests):
def test_enro_tokenizer_prepare_translation_batch(self):
batch = self.tokenizer.prepare_translation_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -2]) # EOS
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
def test_enro_tokenizer_truncation(self):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_translation_batch(
src_text, return_tensors=None, max_length=desired_max_length
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(len(ids), desired_max_length)
@require_torch
class BartHeadTests(unittest.TestCase):
vocab_size = 99