Multilingual BART - (#3602)

- support mbart-en-ro weights
- add MBartTokenizer
This commit is contained in:
Sam Shleifer
2020-04-10 11:25:39 -04:00
committed by GitHub
parent f98d0ef2a2
commit 7a7fdf71f8
7 changed files with 232 additions and 38 deletions

View File

@@ -34,6 +34,8 @@ if is_torch_available():
BartForConditionalGeneration,
BartForSequenceClassification,
BartConfig,
BartTokenizer,
MBartTokenizer,
)
from transformers.modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -41,7 +43,6 @@ if is_torch_available():
invert_mask,
_prepare_bart_decoder_inputs,
)
from transformers.tokenization_bart import BartTokenizer
@require_torch
@@ -55,10 +56,10 @@ class ModelTester:
self.is_training = True
self.use_labels = False
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.hidden_size = 16
self.num_hidden_layers = 2
self.num_attention_heads = 4
self.intermediate_size = 37
self.intermediate_size = 4
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
@@ -105,7 +106,6 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
@@ -196,8 +196,114 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class BartHeadTests(unittest.TestCase):
class BartTranslationTests(unittest.TestCase):
_model = None
@classmethod
def setUpClass(cls):
checkpoint_name = "mbart-large-en-ro"
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1
net_input = {
"input_ids": _long_tensor(
[
[3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
[64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
]
),
"decoder_input_ids": _long_tensor(
[
[250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
[250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
]
),
"generation_mode": False,
}
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
cls.net_input = net_input
return cls
@property
def model(self):
"""Only load the model if needed."""
if self._model is None:
model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")
self._model = model
return self._model
@slow
def test_enro_forward(self):
model = self.model
with torch.no_grad():
logits, *other_stuff = model(**self.net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787])
result_slice = logits[0][0][:3]
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
@slow
def test_enro_generate(self):
model = self.model
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
inputs = {
"input_ids": torch.LongTensor(
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
)
}
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
decoded = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for g in translated_tokens
]
self.assertEqual(expected_translation_romanian, decoded[0])
def test_mbart_enro_config(self):
mbart_models = ["mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
self.assertEqual(v, getattr(config, k))
except AssertionError as e:
e.args += (name, k)
raise
def test_enro_tokenizer(self):
raw = "UN Chief Says There Is No Military Solution in Syria"
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
# TODO(SS): should be [8274, ..., 2, 250020]
self.assertListEqual(expected_result, ids)
def test_mbart_fast_forward(self):
config = BartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
@require_torch
class BartHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self):
@@ -263,13 +369,13 @@ class BartHeadTests(unittest.TestCase):
def test_lm_uneven_forward(self):
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
d_model=14,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
@@ -462,6 +568,7 @@ class BartModelIntegrationTests(unittest.TestCase):
@slow
def test_xsum_summarization_same_as_fairseq(self):
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("bart-large")
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""