MBartForConditionalGeneration (#6441)

* add MBartForConditionalGeneration

* style

* rebase and fixes

* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS

* fix docs

* don't ignore mbart

* doc

* fix mbart fairseq link

* put mbart before bart

* apply doc suggestions
This commit is contained in:
Suraj Patil
2020-08-14 12:51:16 +05:30
committed by GitHub
parent 05810cd80a
commit 680f1337c3
14 changed files with 410 additions and 283 deletions

View File

@@ -11,8 +11,8 @@ if is_torch_available():
import torch
from transformers import (
AutoModelForSeq2SeqLM,
BartConfig,
BartForConditionalGeneration,
MBartConfig,
MBartForConditionalGeneration,
BatchEncoding,
AutoTokenizer,
)
@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
mbart_models = ["facebook/mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
config = MBartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
@@ -102,7 +102,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
raise
def test_mbart_fast_forward(self):
config = BartConfig(
config = MBartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
@@ -115,7 +115,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
add_final_layer_norm=True,
return_dict=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
lm_model = MBartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)