PegasusForConditionalGeneration (torch version) (#6340)

Co-authored-by: Jingqing  Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
Sam Shleifer
2020-08-11 14:31:23 -04:00
committed by GitHub
parent f6cb0f806e
commit 66fa8ceaea
20 changed files with 860 additions and 20 deletions

View File

@@ -37,6 +37,7 @@ from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@@ -150,6 +151,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@@ -287,6 +289,7 @@ if is_torch_available():
XLMForMultipleChoice,
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_bart import (
PretrainedBartModel,
BartForSequenceClassification,

View File

@@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@@ -81,6 +82,7 @@ CONFIG_MAPPING = OrderedDict(
("albert", AlbertConfig,),
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("pegasus", PegasusConfig),
("marian", MarianConfig,),
("mbart", MBartConfig,),
("bart", BartConfig,),

View File

@@ -18,6 +18,7 @@
import logging
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable
logger = logging.getLogger(__name__)
@@ -31,8 +32,73 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
BART_CONFIG_ARGS_DOC = r"""
Args:
vocab_size (:obj:`int`, optional, defaults to 50265):
defines the different tokens that can be represented by `inputs_ids` passed to the forward method.
d_model (:obj:`int`, optional, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (:obj:`int`, optional, defaults to 12):
Number of encoder layers, 16 for pegasus, 6 for bart-base and marian
decoder_layers (:obj:`int`, optional, defaults to 12):
Number of decoder layers, 16 for pegasus, 6 for bart-base and marian
encoder_attention_heads (:obj:`int`, optional, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (:obj:`int`, optional, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
encoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
The non-linear activation function (function or string) in the encoder and pooler.
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
dropout (:obj:`float`, optional, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (:obj:`int`, optional, defaults to 1024):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
add_bias_logits (:obj:`int`, optional, defaults to False):
True for marian only.
normalize_before (:obj:`bool`, optional, defaults to False):
Call layernorm before attention ops. True for pegasus, mbart. False for bart. FIXME: marian?
normalize_embedding (:obj:`bool`, optional, defaults to True):
Call layernorm after embeddings. Only True for Bart.
static_position_embeddings (:obj:`bool`, optional, defaults to False):
Don't learn positional embeddings, use sinusoidal. True for marian, pegasus.
add_final_layer_norm (:obj:`bool`, optional, defaults to False):
Why not add another layernorm?
scale_embedding (:obj:`bool`, optional, defaults to False):
Scale embeddings by diving by sqrt(d_model).
eos_token_id (:obj:`int`, optional, defaults to 2)
End of stream token id.
pad_token_id (:obj:`int`, optional, defaults to 1)
Padding token id.
bos_token_id (:obj:`int`, optional, defaults to 0)
Beginning of stream token id.
encoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
Google "layerdrop arxiv", as its not explainable in one line.
decoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
Google "layerdrop arxiv", as its not explainable in one line.
extra_pos_embeddings: (:obj:`int`, optional, defaults to 2):
How many extra learned positional embeddings to use. Should be pad_token_id+1 for bart.
num_labels: (:obj:`int`, optional, defaults to 2):
for SequenceClassification
is_encoder_decoder (:obj:`int`, optional, defaults to True):
True
"""
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class BartConfig(PretrainedConfig):
r"""
Configuration class for Bart. Parameters are renamed from the fairseq implementation
@@ -42,7 +108,7 @@ class BartConfig(PretrainedConfig):
def __init__(
self,
activation_dropout=0.0,
extra_pos_embeddings=2,
extra_pos_embeddings=2, # FIXME(@sshleifer): delete?
activation_function="gelu",
vocab_size=50265,
d_model=1024,
@@ -81,6 +147,7 @@ class BartConfig(PretrainedConfig):
>>> config = BartConfig.from_pretrained('facebook/bart-large')
>>> model = BartModel(config)
"""
if "hidden_size" in common_kwargs:
raise ValueError("hidden size is called d_model")
@@ -146,3 +213,4 @@ class BartConfig(PretrainedConfig):
class MBartConfig(BartConfig):
model_type = "mbart"
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""

View File

@@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PEGASUS model configuration """
import logging
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
from .file_utils import add_start_docstrings_to_callable
logger = logging.getLogger(__name__)
DEFAULTS = dict(
vocab_size=96103,
max_position_embeddings=512,
d_model=1024,
encoder_ffn_dim=4096,
decoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_attention_heads=16,
encoder_layers=16,
decoder_layers=16,
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.1,
pad_token_id=0,
eos_token_id=1,
is_encoder_decoder=True,
normalize_before=True,
scale_embedding=True,
normalize_embedding=False,
add_final_layer_norm=True,
static_position_embeddings=True,
num_beams=8,
activation_function="relu",
)
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class PegasusConfig(BartConfig):
r"""
:class:`~transformers.PegasusConfig` is the configuration class to store the configuration of a
`PegasusModel`.
"""
model_type = "pegasus"
# The implementation of the config object is in BartConfig
@property
def default_config_parameters(self):
return DEFAULTS

View File

@@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from typing import Dict
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.configuration_pegasus import DEFAULTS
PATTERNS = [
# replace left string with right string to get the relevant state_dict key (identical state dict to bart)
["memory_attention", "encoder_attn"],
["attention", "attn"],
["/", "."],
[".LayerNorm.gamma", "_layer_norm.weight"],
[".LayerNorm.beta", "_layer_norm.bias"],
["r.layer_", "r.layers."],
["output_proj", "out_proj"],
["ffn.dense_1.", "fc2."],
["ffn.dense.", "fc1."],
["ffn_layer_norm", "final_layer_norm"],
["kernel", "weight"],
["encoder_layer_norm.", "encoder.layer_norm."],
["decoder_layer_norm.", "decoder.layer_norm."],
["embeddings.weights", "shared.weight"],
]
def rename_state_dict_key(k):
for pegasus_name, bart_name in PATTERNS:
k = k.replace(pegasus_name, bart_name)
return k
# See appendix C of paper for all hyperparams
max_gen_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
"large": 256, # @sshleifer chose arbitrarily
}
max_model_length = {
"xsum": 512,
"cnn_dailymail": 1024,
"newsroom": 512,
"wikihow": 512,
"multi_news": 1024,
"reddit_tifu": 512,
"big_patent": 1024,
"arxiv": 1024,
"pubmed": 1024,
"gigaword": 128,
"aeslc": 512,
"billsum": 1024,
"large": 1024,
}
expected_alpha = {
"multinews": 0.9,
"wikihow": 0.6,
"reddit_tifu": 0.6,
"big_patent": 0.7,
"gigaword": 0.6,
"aeslc": 0.6,
"billsum": 0.6,
} # otherwise 0.8
# TODO(SS): one constant
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
cfg_kwargs = DEFAULTS.copy()
cfg_kwargs.update(cfg_updates)
cfg = PegasusConfig(**cfg_updates)
bart = PegasusForConditionalGeneration(cfg)
sd = bart.model.state_dict()
mapping = {}
for k, v in tf_weights.items():
new_k = rename_state_dict_key(k)
if new_k not in sd:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if "dense" in k or "proj" in new_k:
v = v.T
mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"
# make sure embedding.padding_idx is respected
mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
mapping.update(**empty_biases)
missing, extra = bart.model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return bart
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["Adafactor", "global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any([pat in name for pat in ignore_name])
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
# save tokenizer first
dataset = Path(ckpt_path).parent.name
desired_max_model_length = max_model_length[dataset]
tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
assert tok.model_max_length == desired_max_model_length
tok.save_pretrained(save_dir)
# convert model
tf_weights = get_tf_weights_as_numpy(ckpt_path)
cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
if args.save_dir is None:
args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}"
convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)

View File

@@ -34,6 +34,7 @@ from .configuration_auto import (
LongformerConfig,
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
@@ -125,6 +126,7 @@ from .modeling_mobilebert import (
MobileBertModel,
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
@@ -283,6 +285,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
(BartConfig, BartForConditionalGeneration),
(EncoderDecoderConfig, EncoderDecoderModel),

View File

@@ -19,9 +19,7 @@ from .configuration_marian import MarianConfig
from .modeling_bart import BartForConditionalGeneration
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
]
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
class MarianMTModel(BartForConditionalGeneration):

View File

@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Pegasus model, ported from https://github.com/google-research/pegasus"""
from .configuration_pegasus import PegasusConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
config_class = PegasusConfig
r"""
Pytorch version of google's pegasus model for summarization.
Model API is identical to BartForConditionalGeneration.
Available models are listed at `Model List <https://huggingface.co/models?search=pegasus>`__
Examples::
>>> from transformers import PegasusTokenizer, PegasusForConditionalGeneration
>>> from typing import List
>>> PGE_ARTICLE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
>>> mname = "google/pegasus-xsum"
>>> model = PegasusForConditionalGeneration.from_pretrained(mname)
>>> tok = PegasusTokenizer.from_pretrained(mname)
>>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE]) # don't need tgt_text for inference
>>> gen = model.generate(**batch) # for forward pass: model(**batch)
>>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
>>> assert summary == "California's largest electricity provider has turned off power to tens of thousands of customers."
"""
# All the code is in src/transformers/modeling_bart.py

View File

@@ -30,8 +30,11 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
LongformerConfig,
MarianConfig,
MBartConfig,
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
@@ -41,8 +44,6 @@ from .configuration_auto import (
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer, MBartTokenizer
@@ -58,6 +59,7 @@ from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@@ -79,6 +81,7 @@ TOKENIZER_MAPPING = OrderedDict(
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)),
(CamembertConfig, (CamembertTokenizer, None)),
(PegasusConfig, (PegasusTokenizer, None)),
(MBartConfig, (MBartTokenizer, None)),
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(MarianConfig, (MarianTokenizer, None)),

View File

@@ -0,0 +1,193 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
from transformers.tokenization_reformer import ReformerTokenizer
from .tokenization_utils_base import BatchEncoding
class PegasusTokenizer(ReformerTokenizer):
offset = 103 # entries 2-104 are only used for pretraining
vocab_files_names = {"vocab_file": "spiece.model"}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dont use reserved words added_token_encoder, added_tokens_decoder because of
# AssertionError: Non-consecutive added token '1' found. in from_pretrained
assert len(self.added_tokens_decoder) == 0
self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
# entries 2-104 are only used for pretraining and called unk_2, ...unk_104
self.encoder.update({i: f"unk_{i}" for i in range(2, self.offset + 2)})
self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}
def _convert_token_to_id(self, token: str) -> int:
""" Converts a token (str) in an id using the vocab. """
if token in self.decoder:
return self.decoder[token]
elif token in self.added_tokens_decoder:
return self.added_tokens_decoder[token]
sp_id = self.sp_model.piece_to_id(token)
return sp_id + self.offset
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.encoder:
return self.encoder[index]
elif index in self.added_tokens_encoder:
return self.added_tokens_encoder[index]
else:
# assert index > self.offset, f"cannot decode ids between 2 and {self.offset}. Got {index}"
token = self.sp_model.IdToPiece(index - self.offset)
return token
@property
def vocab_size(self) -> int:
return len(self.sp_model) + self.offset
def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def num_special_tokens_to_add(self, pair=False):
"""Just EOS"""
return 1
def _special_token_mask(self, seq):
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
assert all_special_ids == set([0, 1])
return [1 if x in all_special_ids else 0 for x in seq]
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""Get list where entries are [1] if a token is [eos] or [pad] else 0."""
if already_has_special_tokens:
return self._special_token_mask(token_ids_0)
elif token_ids_1 is None:
return self._special_token_mask(token_ids_0) + [1]
else:
return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""
Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
- single sequence: ``X </s>``
- pair of sequences: ``A B </s>`` (not intended use)
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
return_tensors: str = "pt",
truncation=True,
padding="longest",
) -> BatchEncoding:
"""
Prepare model inputs for summarization or translation.
Arguments:
src_texts: (:obj:`list`):
list of documents to summarize or source language texts
tgt_texts: (:obj:`list`, `optional`):
list of tgt language texts or summaries.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries)
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
Return:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if "" in src_texts:
raise ValueError(f"found empty string in src_texts: {src_texts}")
tokenizer_kwargs = dict(
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
truncation=truncation,
padding=padding,
)
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
if tgt_texts is None:
return model_inputs
if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
return model_inputs