Rename BartForMaskedLM -> BartForConditionalGeneration (#3114)
* improved documentation
This commit is contained in:
@@ -29,7 +29,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BartModel,
|
||||
BartForMaskedLM,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartConfig,
|
||||
)
|
||||
@@ -97,7 +97,9 @@ def prepare_bart_inputs_dict(
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BartModel, BartForMaskedLM, BartForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
is_encoder_decoder = True
|
||||
# TODO(SS): fix the below in a separate PR
|
||||
test_pruning = False
|
||||
@@ -221,8 +223,8 @@ class BartHeadTests(unittest.TestCase):
|
||||
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_model = BartForConditionalGeneration(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(
|
||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||
@@ -243,15 +245,15 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
|
||||
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
def test_generate_beam_search(self):
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
@@ -264,7 +266,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
max_position_embeddings=48,
|
||||
output_past=True,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
lm_model.eval()
|
||||
|
||||
new_input_ids = lm_model.generate(
|
||||
@@ -376,7 +378,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_cnn_summarization_same_as_fairseq(self):
|
||||
hf = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user