Rename BartForMaskedLM -> BartForConditionalGeneration (#3114)
* improved documentation
This commit is contained in:
@@ -7,7 +7,7 @@ file a `Github Issue <https://github.com/huggingface/transformers/issues/new?ass
|
|||||||
Paper
|
Paper
|
||||||
~~~~~
|
~~~~~
|
||||||
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
|
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
|
||||||
According to the abstract:
|
According to the abstract,
|
||||||
|
|
||||||
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
|
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
|
||||||
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
|
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
|
||||||
@@ -18,26 +18,28 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma
|
|||||||
|
|
||||||
Implementation Notes
|
Implementation Notes
|
||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting.
|
- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use BartTokenizer.encode to get the proper splitting.
|
||||||
- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs.
|
- The forward pass of ``BartModel`` will create decoder inputs (using the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``) if they are not passed. This is different than some other modeling APIs.
|
||||||
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space.
|
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to ``fairseq.encode`` starts with a space.
|
||||||
- Decoder inputs are created automatically by the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``
|
- ``BartForConditionalGeneration.generate`` should be used for conditional generation tasks like summarization, see the example in that docstrings
|
||||||
BartModel
|
- Models that load the ``"bart-large-cnn"`` weights will not have a ``mask_token_id``, or be able to perform mask filling tasks.
|
||||||
- ``MaskedLM.generate`` should be used for summarization, see the example in that docstrings
|
|
||||||
|
|
||||||
|
|
||||||
BartModel
|
BartModel
|
||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.BartModel
|
.. autoclass:: transformers.BartModel
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
|
||||||
|
|
||||||
BartForMaskedLM
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
.. autoclass:: transformers.BartForMaskedLM
|
BartForConditionalGeneration
|
||||||
:members: forward, generate
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.BartForConditionalGeneration
|
||||||
|
:members: generate, forward
|
||||||
|
|
||||||
|
|
||||||
BartForSequenceClassification
|
BartForSequenceClassification
|
||||||
@@ -52,8 +54,3 @@ BartConfig
|
|||||||
.. autoclass:: transformers.BartConfig
|
.. autoclass:: transformers.BartConfig
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
Automatic Creation of Decoder Inputs
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
This is enabled by default
|
|
||||||
|
|
||||||
.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import BartForMaskedLM, BartTokenizer
|
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@@ -18,7 +18,7 @@ def chunks(lst, n):
|
|||||||
|
|
||||||
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
||||||
fout = Path(out_file).open("w")
|
fout = Path(out_file).open("w")
|
||||||
model = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,)
|
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,)
|
||||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
for batch in tqdm(list(chunks(lns, batch_size))):
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
||||||
|
|||||||
@@ -206,7 +206,11 @@ if is_torch_available():
|
|||||||
XLMForQuestionAnsweringSimple,
|
XLMForQuestionAnsweringSimple,
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
)
|
)
|
||||||
from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM
|
from .modeling_bart import (
|
||||||
|
BartForSequenceClassification,
|
||||||
|
BartModel,
|
||||||
|
BartForConditionalGeneration,
|
||||||
|
)
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
|
|||||||
@@ -23,7 +23,13 @@ import fairseq
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
|
from transformers import (
|
||||||
|
BartConfig,
|
||||||
|
BartForConditionalGeneration,
|
||||||
|
BartForSequenceClassification,
|
||||||
|
BartModel,
|
||||||
|
BartTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
|
||||||
@@ -86,14 +92,14 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
|||||||
model.eval()
|
model.eval()
|
||||||
# Check results
|
# Check results
|
||||||
|
|
||||||
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
|
if checkpoint_path == "bart.large.cnn":
|
||||||
model = BartForMaskedLM(config, base_model=model)
|
model = BartForConditionalGeneration(config, base_model=model)
|
||||||
assert "lm_head.weight" in model.state_dict()
|
assert "lm_head.weight" in model.state_dict()
|
||||||
assert model.lm_head.out_features == config.max_position_embeddings
|
assert model.lm_head.out_features == config.max_position_embeddings
|
||||||
model.eval()
|
model.eval()
|
||||||
our_outputs = model.model.forward(tokens)[0]
|
our_outputs = model.model(tokens)[0]
|
||||||
else:
|
else:
|
||||||
our_outputs = model.forward(tokens)[0]
|
our_outputs = model(tokens)[0]
|
||||||
assert their_output.shape == our_outputs.shape
|
assert their_output.shape == our_outputs.shape
|
||||||
assert (their_output == our_outputs).all().item()
|
assert (their_output == our_outputs).all().item()
|
||||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
|
|||||||
@@ -45,7 +45,12 @@ from .modeling_albert import (
|
|||||||
AlbertForTokenClassification,
|
AlbertForTokenClassification,
|
||||||
AlbertModel,
|
AlbertModel,
|
||||||
)
|
)
|
||||||
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
|
from .modeling_bart import (
|
||||||
|
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
BartForConditionalGeneration,
|
||||||
|
BartForSequenceClassification,
|
||||||
|
BartModel,
|
||||||
|
)
|
||||||
from .modeling_bert import (
|
from .modeling_bert import (
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
@@ -166,7 +171,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
|||||||
(AlbertConfig, AlbertForMaskedLM),
|
(AlbertConfig, AlbertForMaskedLM),
|
||||||
(CamembertConfig, CamembertForMaskedLM),
|
(CamembertConfig, CamembertForMaskedLM),
|
||||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||||
(BartConfig, BartForMaskedLM),
|
(BartConfig, BartForConditionalGeneration),
|
||||||
(RobertaConfig, RobertaForMaskedLM),
|
(RobertaConfig, RobertaForMaskedLM),
|
||||||
(BertConfig, BertForPreTraining),
|
(BertConfig, BertForPreTraining),
|
||||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||||
@@ -186,7 +191,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||||||
(AlbertConfig, AlbertForMaskedLM),
|
(AlbertConfig, AlbertForMaskedLM),
|
||||||
(CamembertConfig, CamembertForMaskedLM),
|
(CamembertConfig, CamembertForMaskedLM),
|
||||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||||
(BartConfig, BartForMaskedLM),
|
(BartConfig, BartForConditionalGeneration),
|
||||||
(RobertaConfig, RobertaForMaskedLM),
|
(RobertaConfig, RobertaForMaskedLM),
|
||||||
(BertConfig, BertForMaskedLM),
|
(BertConfig, BertForMaskedLM),
|
||||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||||
|
|||||||
@@ -778,21 +778,6 @@ def _filter_out_falsey_values(tup) -> Tuple:
|
|||||||
return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
|
return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
|
||||||
|
|
||||||
|
|
||||||
RET_DOCSTRING = r"""
|
|
||||||
Return:
|
|
||||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
|
||||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
|
||||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
|
||||||
heads.
|
|
||||||
"""
|
|
||||||
# Public API
|
# Public API
|
||||||
|
|
||||||
|
|
||||||
@@ -863,10 +848,9 @@ class BartModel(PretrainedBartModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare BART Model with a language modeling head. This is the model used for summarization.",
|
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING,
|
||||||
BART_START_DOCSTRING,
|
|
||||||
)
|
)
|
||||||
class BartForMaskedLM(PretrainedBartModel):
|
class BartForConditionalGeneration(PretrainedBartModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
|
|
||||||
def __init__(self, config: BartConfig):
|
def __init__(self, config: BartConfig):
|
||||||
@@ -919,11 +903,18 @@ class BartForMaskedLM(PretrainedBartModel):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
tokenizer = BartTokenizer.from_pretrained('bart-large')
|
# Mask filling only works for bart-large
|
||||||
model = BartForMaskedLM.from_pretrained('bart-large')
|
from transformers import BartTokenizer, BartForConditionalGeneration
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
tokenizer = AutoTokenizer.from_pretrained('bart-large')
|
||||||
outputs = model(input_ids=input_ids, lm_labels=input_ids)
|
TXT = "My friends are <mask> but they eat too many carbs."
|
||||||
loss, prediction_scores = outputs[:2]
|
model = BartForConditionalGeneration.from_pretrained('bart-large')
|
||||||
|
input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
|
||||||
|
logits = model(input_ids)[0]
|
||||||
|
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
||||||
|
probs = logits[0, masked_index].softmax(dim=0)
|
||||||
|
values, predictions = probs.topk(5)
|
||||||
|
tokenizer.decode(predictions).split()
|
||||||
|
# ['good', 'great', 'all', 'really', 'very']
|
||||||
"""
|
"""
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -992,8 +983,7 @@ class BartForMaskedLM(PretrainedBartModel):
|
|||||||
min_len=0,
|
min_len=0,
|
||||||
no_repeat_ngram_size=0,
|
no_repeat_ngram_size=0,
|
||||||
):
|
):
|
||||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
r""" Generates summaries using the lm-head and greedy beam search
|
||||||
and beam-search.
|
|
||||||
|
|
||||||
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
|
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
|
||||||
|
|
||||||
@@ -1031,16 +1021,16 @@ class BartForMaskedLM(PretrainedBartModel):
|
|||||||
sequence_length is <= max_length (examples can finish early)
|
sequence_length is <= max_length (examples can finish early)
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
|
||||||
config = BartConfig(vocab_size=50264, output_past=True)
|
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
|
||||||
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
|
config = BartConfig(vocab_size=50264, output_past=True) # no mask_token_id
|
||||||
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
|
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn', config=config)
|
||||||
|
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
|
||||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||||
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
|
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
|
||||||
# Generate Summary
|
# Generate Summary
|
||||||
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
|
summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
|
||||||
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
|
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
bos_token_id = self.config.bos_token_id
|
bos_token_id = self.config.bos_token_id
|
||||||
pad_token_id = self.config.pad_token_id
|
pad_token_id = self.config.pad_token_id
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
BartModel,
|
BartModel,
|
||||||
BartForMaskedLM,
|
BartForConditionalGeneration,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
BartConfig,
|
BartConfig,
|
||||||
)
|
)
|
||||||
@@ -97,7 +97,9 @@ def prepare_bart_inputs_dict(
|
|||||||
@require_torch
|
@require_torch
|
||||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (BartModel, BartForMaskedLM, BartForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||||
|
)
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
# TODO(SS): fix the below in a separate PR
|
# TODO(SS): fix the below in a separate PR
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -221,8 +223,8 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_lm_forward(self):
|
def test_lm_forward(self):
|
||||||
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForConditionalGeneration(config)
|
||||||
lm_model.to(torch_device)
|
lm_model.to(torch_device)
|
||||||
loss, logits, enc_features = lm_model.forward(
|
loss, logits, enc_features = lm_model.forward(
|
||||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||||
@@ -243,15 +245,15 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
decoder_ffn_dim=32,
|
decoder_ffn_dim=32,
|
||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
)
|
)
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
|
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
|
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||||
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
|
loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
|
||||||
expected_shape = (*summary.shape, config.vocab_size)
|
expected_shape = (*summary.shape, config.vocab_size)
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
def test_generate_beam_search(self):
|
def test_generate_beam_search(self):
|
||||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
|
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
|
||||||
config = BartConfig(
|
config = BartConfig(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
d_model=24,
|
d_model=24,
|
||||||
@@ -264,7 +266,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
output_past=True,
|
output_past=True,
|
||||||
)
|
)
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||||
lm_model.eval()
|
lm_model.eval()
|
||||||
|
|
||||||
new_input_ids = lm_model.generate(
|
new_input_ids = lm_model.generate(
|
||||||
@@ -376,7 +378,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_cnn_summarization_same_as_fairseq(self):
|
def test_cnn_summarization_same_as_fairseq(self):
|
||||||
hf = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||||
tok = BartTokenizer.from_pretrained("bart-large")
|
tok = BartTokenizer.from_pretrained("bart-large")
|
||||||
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
||||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user