MBartForConditionalGeneration (#6441)
* add MBartForConditionalGeneration * style * rebase and fixes * add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS * fix docs * don't ignore mbart * doc * fix mbart fairseq link * put mbart before bart * apply doc suggestions
This commit is contained in:
@@ -22,7 +22,7 @@ import logging
|
||||
# Configurations
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
||||
from .configuration_bart import BartConfig, MBartConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
@@ -34,6 +34,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mbart import MBartConfig
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
@@ -131,7 +132,7 @@ from .pipelines import (
|
||||
# Tokenizers
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
@@ -149,6 +150,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
from .tokenization_mbart import MBartTokenizer
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
@@ -298,6 +300,7 @@ if is_torch_available():
|
||||
BartForQuestionAnswering,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
from .modeling_mbart import MBartForConditionalGeneration
|
||||
from .modeling_marian import MarianMTModel
|
||||
from .tokenization_marian import MarianTokenizer
|
||||
from .modeling_roberta import (
|
||||
|
||||
@@ -19,7 +19,7 @@ import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
|
||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
@@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
|
||||
from .configuration_mobilebert import MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
@@ -52,6 +53,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
for pretrained_map in [
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
|
||||
@@ -32,6 +32,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
|
||||
}
|
||||
|
||||
BART_CONFIG_ARGS_DOC = r"""
|
||||
Args:
|
||||
vocab_size (:obj:`int`, optional, defaults to 50265):
|
||||
@@ -209,8 +210,3 @@ class BartConfig(PretrainedConfig):
|
||||
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
|
||||
logger.info("This configuration is a mixture of MBART and BART settings")
|
||||
return False
|
||||
|
||||
|
||||
class MBartConfig(BartConfig):
|
||||
model_type = "mbart"
|
||||
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""
|
||||
|
||||
32
src/transformers/configuration_mbart.py
Normal file
32
src/transformers/configuration_mbart.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MBART configuration """
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_bart import BartConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
"facebook/mbart-large-cc25": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-cc25/config.json",
|
||||
}
|
||||
|
||||
|
||||
class MBartConfig(BartConfig):
|
||||
model_type = "mbart"
|
||||
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""
|
||||
@@ -32,6 +32,7 @@ from .configuration_auto import (
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
MBartConfig,
|
||||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
@@ -116,6 +117,7 @@ from .modeling_longformer import (
|
||||
LongformerModel,
|
||||
)
|
||||
from .modeling_marian import MarianMTModel
|
||||
from .modeling_mbart import MBartForConditionalGeneration
|
||||
from .modeling_mobilebert import (
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForMultipleChoice,
|
||||
@@ -289,6 +291,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
(T5Config, T5ForConditionalGeneration),
|
||||
(PegasusConfig, PegasusForConditionalGeneration),
|
||||
(MarianConfig, MarianMTModel),
|
||||
(MBartConfig, MBartForConditionalGeneration),
|
||||
(BartConfig, BartForConditionalGeneration),
|
||||
(EncoderDecoderConfig, EncoderDecoderModel),
|
||||
]
|
||||
|
||||
38
src/transformers/modeling_mbart.py
Normal file
38
src/transformers/modeling_mbart.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from .configuration_mbart import MBartConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_bart import BartForConditionalGeneration
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "MBartConfig"
|
||||
_TOKENIZER_FOR_DOC = "MBartTokenizer"
|
||||
|
||||
MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/mbart-large-cc25",
|
||||
"facebook/mbart-large-en-ro",
|
||||
# See all multilingual BART models at https://huggingface.co/models?filter=mbart
|
||||
]
|
||||
|
||||
MBART_START_DOCSTRING = r"""
|
||||
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ sub-class.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
||||
usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the
|
||||
model. Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The BART Model with a language modeling head. Can be used for machine translation.", MBART_START_DOCSTRING
|
||||
)
|
||||
class MBartForConditionalGeneration(BartForConditionalGeneration):
|
||||
"""
|
||||
This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the
|
||||
superclass for the appropriate documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = MBartConfig
|
||||
@@ -46,7 +46,7 @@ from .configuration_auto import (
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
@@ -57,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
from .tokenization_marian import MarianTokenizer
|
||||
from .tokenization_mbart import MBartTokenizer
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
|
||||
@@ -14,13 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -55,258 +50,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
|
||||
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
FAIRSEQ_LANGUAGE_CODES = [
|
||||
"ar_AR",
|
||||
"cs_CZ",
|
||||
"de_DE",
|
||||
"en_XX",
|
||||
"es_XX",
|
||||
"et_EE",
|
||||
"fi_FI",
|
||||
"fr_XX",
|
||||
"gu_IN",
|
||||
"hi_IN",
|
||||
"it_IT",
|
||||
"ja_XX",
|
||||
"kk_KZ",
|
||||
"ko_KR",
|
||||
"lt_LT",
|
||||
"lv_LV",
|
||||
"my_MM",
|
||||
"ne_NP",
|
||||
"nl_XX",
|
||||
"ro_RO",
|
||||
"ru_RU",
|
||||
"si_LK",
|
||||
"tr_TR",
|
||||
"vi_VN",
|
||||
"zh_CN",
|
||||
]
|
||||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
|
||||
Other tokenizer methods like ``encode`` do not work properly.
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
|
||||
``<language code> <tokens> <eos>``` for target language documents.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizer
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
|
||||
... )
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
|
||||
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
|
||||
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.sp_model_size = len(self.sp_model)
|
||||
self.lang_code_to_id = {
|
||||
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||
self.cur_lang_code = self.lang_code_to_id["en_XX"]
|
||||
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||
|
||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
|
||||
An MBART sequence has the following format, where ``X`` represents the sequence:
|
||||
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
|
||||
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
|
||||
BOS is never used.
|
||||
Pairs of sequences are not the expected use case, but they will be handled without a separator.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` methods.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Set to True if the token list is already formatted with special tokens for the model
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
prefix_ones = [1] * len(self.prefix_tokens)
|
||||
suffix_ones = [1] * len(self.suffix_tokens)
|
||||
if token_ids_1 is None:
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||
|
||||
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
||||
|
||||
Arguments:
|
||||
src_texts: (:obj:`list`):
|
||||
list of documents to summarize or source language texts
|
||||
src_lang: (:obj:`str`, `optional`, default='en_XX'):
|
||||
default en_XX (english), the language we are translating from
|
||||
tgt_texts: (:obj:`list`, `optional`):
|
||||
list of tgt language texts or summaries.
|
||||
tgt_lang: (:obj:`str`, `optional`, default='ro_RO'):
|
||||
default ro_RO (romanian), the language we are translating to
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
|
||||
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
|
||||
length is required by one of the truncation/padding parameters. If the model has no specific maximum
|
||||
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries)
|
||||
If left unset or set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
|
||||
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
|
||||
This does not include causal mask, which is built by the model.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
|
||||
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
||||
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = self.max_len
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
decoder_inputs: BatchEncoding = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
|
||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
||||
self.prefix_tokens = []
|
||||
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||
|
||||
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
|
||||
self.cur_lang_code = self.lang_code_to_id[lang]
|
||||
self.prefix_tokens = [self.cur_lang_code]
|
||||
self.suffix_tokens = [self.eos_token_id]
|
||||
|
||||
279
src/transformers/tokenization_mbart.py
Normal file
279
src/transformers/tokenization_mbart.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
FAIRSEQ_LANGUAGE_CODES = [
|
||||
"ar_AR",
|
||||
"cs_CZ",
|
||||
"de_DE",
|
||||
"en_XX",
|
||||
"es_XX",
|
||||
"et_EE",
|
||||
"fi_FI",
|
||||
"fr_XX",
|
||||
"gu_IN",
|
||||
"hi_IN",
|
||||
"it_IT",
|
||||
"ja_XX",
|
||||
"kk_KZ",
|
||||
"ko_KR",
|
||||
"lt_LT",
|
||||
"lv_LV",
|
||||
"my_MM",
|
||||
"ne_NP",
|
||||
"nl_XX",
|
||||
"ro_RO",
|
||||
"ru_RU",
|
||||
"si_LK",
|
||||
"tr_TR",
|
||||
"vi_VN",
|
||||
"zh_CN",
|
||||
]
|
||||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
|
||||
Other tokenizer methods like ``encode`` do not work properly.
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
|
||||
``<language code> <tokens> <eos>``` for target language documents.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizer
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
|
||||
... )
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
|
||||
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
|
||||
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.sp_model_size = len(self.sp_model)
|
||||
self.lang_code_to_id = {
|
||||
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||
self.cur_lang_code = self.lang_code_to_id["en_XX"]
|
||||
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||
|
||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
|
||||
An MBART sequence has the following format, where ``X`` represents the sequence:
|
||||
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
|
||||
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
|
||||
BOS is never used.
|
||||
Pairs of sequences are not the expected use case, but they will be handled without a separator.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` methods.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Set to True if the token list is already formatted with special tokens for the model
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
prefix_ones = [1] * len(self.prefix_tokens)
|
||||
suffix_ones = [1] * len(self.suffix_tokens)
|
||||
if token_ids_1 is None:
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||
|
||||
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
||||
|
||||
Arguments:
|
||||
src_texts: (:obj:`list`):
|
||||
list of documents to summarize or source language texts
|
||||
src_lang: (:obj:`str`, `optional`, default='en_XX'):
|
||||
default en_XX (english), the language we are translating from
|
||||
tgt_texts: (:obj:`list`, `optional`):
|
||||
list of tgt language texts or summaries.
|
||||
tgt_lang: (:obj:`str`, `optional`, default='ro_RO'):
|
||||
default ro_RO (romanian), the language we are translating to
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
|
||||
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
|
||||
length is required by one of the truncation/padding parameters. If the model has no specific maximum
|
||||
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries)
|
||||
If left unset or set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
|
||||
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
|
||||
This does not include causal mask, which is built by the model.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
|
||||
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
||||
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = self.max_len
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
decoder_inputs: BatchEncoding = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
|
||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
||||
self.prefix_tokens = []
|
||||
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||
|
||||
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
|
||||
self.cur_lang_code = self.lang_code_to_id[lang]
|
||||
self.prefix_tokens = [self.cur_lang_code]
|
||||
self.suffix_tokens = [self.eos_token_id]
|
||||
Reference in New Issue
Block a user