MBartForConditionalGeneration (#6441)
* add MBartForConditionalGeneration * style * rebase and fixes * add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS * fix docs * don't ignore mbart * doc * fix mbart fairseq link * put mbart before bart * apply doc suggestions
This commit is contained in:
@@ -11,8 +11,8 @@ if is_torch_available():
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
MBartConfig,
|
||||
MBartForConditionalGeneration,
|
||||
BatchEncoding,
|
||||
AutoTokenizer,
|
||||
)
|
||||
@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
mbart_models = ["facebook/mbart-large-en-ro"]
|
||||
expected = {"scale_embedding": True, "output_past": True}
|
||||
for name in mbart_models:
|
||||
config = BartConfig.from_pretrained(name)
|
||||
config = MBartConfig.from_pretrained(name)
|
||||
self.assertTrue(config.is_valid_mbart())
|
||||
for k, v in expected.items():
|
||||
try:
|
||||
@@ -102,7 +102,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
raise
|
||||
|
||||
def test_mbart_fast_forward(self):
|
||||
config = BartConfig(
|
||||
config = MBartConfig(
|
||||
vocab_size=99,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
@@ -115,7 +115,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
add_final_layer_norm=True,
|
||||
return_dict=True,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
lm_model = MBartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
|
||||
Reference in New Issue
Block a user