Multilingual BART - (#3602)
- support mbart-en-ro weights
- add MBartTokenizer
This commit is contained in:
@@ -34,6 +34,8 @@ if is_torch_available():
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -41,7 +43,6 @@ if is_torch_available():
|
||||
invert_mask,
|
||||
_prepare_bart_decoder_inputs,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -55,10 +56,10 @@ class ModelTester:
|
||||
self.is_training = True
|
||||
self.use_labels = False
|
||||
self.vocab_size = 99
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.hidden_size = 16
|
||||
self.num_hidden_layers = 2
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 37
|
||||
self.intermediate_size = 4
|
||||
self.hidden_act = "gelu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
@@ -105,7 +106,6 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
@@ -196,8 +196,114 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
class BartTranslationTests(unittest.TestCase):
    """Tests for the ``mbart-large-en-ro`` checkpoint loaded through the BART classes.

    The tokenizer and a fixed batch of token ids are prepared once per class in
    ``setUpClass``. The pretrained weights are loaded lazily through the
    ``model`` property, so tests that never touch ``self.model`` do not load them.
    """

    checkpoint_name = "mbart-large-en-ro"
    # Class-level cache for the lazily loaded model (see the ``model`` property).
    _model = None

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer and build the shared inputs used by ``test_enro_forward``."""
        cls.tokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        net_input = {
            "input_ids": _long_tensor(
                [
                    [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
                    [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
                ]
            ),
            "decoder_input_ids": _long_tensor(
                [
                    [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
                    [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
                ]
            ),
            "generation_mode": False,
        }
        # Attend to every position that is not the padding id.
        net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
        cls.net_input = net_input

    @property
    def model(self):
        """Only load the model if needed, and share it between all tests of the class."""
        cls = type(self)
        if cls._model is None:
            # unittest builds a new TestCase instance for every test method, so the
            # cache has to live on the class; storing it on ``self`` would reload
            # the checkpoint for each test. The model is moved to ``torch_device``
            # because the tests feed it inputs placed on that device.
            cls._model = BartForConditionalGeneration.from_pretrained(cls.checkpoint_name).to(torch_device)
        return cls._model

    @slow
    def test_enro_forward(self):
        """Compare the first three logits of a forward pass against reference values."""
        model = self.model
        # ``.to`` is a no-op for tensors already on ``torch_device``; non-tensor
        # entries (the ``generation_mode`` flag) are passed through untouched.
        net_input = {
            key: (value.to(torch_device) if torch.is_tensor(value) else value)
            for key, value in self.net_input.items()
        }
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)

        # Keep the reference on the same device as the logits; torch.allclose
        # cannot compare tensors that live on different devices.
        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
        result_slice = logits[0][0][:3]
        self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))

    @slow
    def test_enro_generate(self):
        """Generate with beam search from fixed token ids and compare the decoded text."""
        model = self.model
        # The ids below were produced from:
        #   example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
        #   tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt")
        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"

        input_ids = torch.tensor(
            [[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]],  # 250004
            dtype=torch.long,
            device=torch_device,
        )
        translated_tokens = model.generate(input_ids=input_ids, num_beams=5)
        decoded = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for g in translated_tokens
        ]
        self.assertEqual(expected_translation_romanian, decoded[0])

    def test_mbart_enro_config(self):
        """Check that the checkpoint config reports itself as a valid mBART config."""
        mbart_models = [self.checkpoint_name]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                # Name the checkpoint and the attribute in the failure message.
                self.assertEqual(v, getattr(config, k), msg="{}: {}".format(name, k))

    def test_enro_tokenizer(self):
        """Check the ids the tokenizer produces for a fixed English sentence."""
        raw = "UN Chief Says There Is No Military Solution in Syria"
        ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
        expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
        # TODO(SS): should be [8274, ..., 2, 250020]
        self.assertListEqual(expected_result, ids)

    def test_mbart_fast_forward(self):
        """Run a tiny randomly initialised model with ``add_final_layer_norm`` and check the logits shape."""
        config = BartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = BartForConditionalGeneration(config).to(torch_device)
        # Build integer id tensors directly instead of converting from float.
        context = torch.tensor(
            [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=torch.long, device=torch_device
        )
        summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=torch.long, device=torch_device)
        # Three outputs are expected when ``lm_labels`` is passed; only the logits are checked.
        _loss, logits, _enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
vocab_size = 99
|
||||
|
||||
def _get_config_and_data(self):
|
||||
@@ -263,13 +369,13 @@ class BartHeadTests(unittest.TestCase):
|
||||
def test_lm_uneven_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
d_model=14,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
encoder_ffn_dim=8,
|
||||
decoder_ffn_dim=8,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
@@ -462,6 +568,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_xsum_summarization_same_as_fairseq(self):
|
||||
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
|
||||
self.assertFalse(model.config.is_valid_mbart())
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
|
||||
Reference in New Issue
Block a user