PegasusForConditionalGeneration (torch version) (#6340)
Co-authored-by: Jingqing Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
@@ -37,6 +37,7 @@ from .configuration_marian import MarianConfig
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
@@ -150,6 +151,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
@@ -287,6 +289,7 @@ if is_torch_available():
|
||||
XLMForMultipleChoice,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_bart import (
|
||||
PretrainedBartModel,
|
||||
BartForSequenceClassification,
|
||||
|
||||
@@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mobilebert import MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .configuration_reformer import ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
@@ -81,6 +82,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("albert", AlbertConfig,),
|
||||
("camembert", CamembertConfig,),
|
||||
("xlm-roberta", XLMRobertaConfig,),
|
||||
("pegasus", PegasusConfig),
|
||||
("marian", MarianConfig,),
|
||||
("mbart", MBartConfig,),
|
||||
("bart", BartConfig,),
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
import logging
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,8 +32,73 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
|
||||
}
|
||||
BART_CONFIG_ARGS_DOC = r"""
|
||||
Args:
|
||||
vocab_size (:obj:`int`, optional, defaults to 50265):
|
||||
defines the different tokens that can be represented by `inputs_ids` passed to the forward method.
|
||||
d_model (:obj:`int`, optional, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (:obj:`int`, optional, defaults to 12):
|
||||
Number of encoder layers, 16 for pegasus, 6 for bart-base and marian
|
||||
decoder_layers (:obj:`int`, optional, defaults to 12):
|
||||
Number of decoder layers, 16 for pegasus, 6 for bart-base and marian
|
||||
encoder_attention_heads (:obj:`int`, optional, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (:obj:`int`, optional, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
||||
dropout (:obj:`float`, optional, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (:obj:`float`, optional, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (:obj:`float`, optional, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (:obj:`float`, optional, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
max_position_embeddings (:obj:`int`, optional, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (:obj:`float`, optional, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
add_bias_logits (:obj:`int`, optional, defaults to False):
|
||||
True for marian only.
|
||||
normalize_before (:obj:`bool`, optional, defaults to False):
|
||||
Call layernorm before attention ops. True for pegasus, mbart. False for bart. FIXME: marian?
|
||||
normalize_embedding (:obj:`bool`, optional, defaults to True):
|
||||
Call layernorm after embeddings. Only True for Bart.
|
||||
static_position_embeddings (:obj:`bool`, optional, defaults to False):
|
||||
Don't learn positional embeddings, use sinusoidal. True for marian, pegasus.
|
||||
add_final_layer_norm (:obj:`bool`, optional, defaults to False):
|
||||
Why not add another layernorm?
|
||||
scale_embedding (:obj:`bool`, optional, defaults to False):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
eos_token_id (:obj:`int`, optional, defaults to 2)
|
||||
End of stream token id.
|
||||
pad_token_id (:obj:`int`, optional, defaults to 1)
|
||||
Padding token id.
|
||||
bos_token_id (:obj:`int`, optional, defaults to 0)
|
||||
Beginning of stream token id.
|
||||
encoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
|
||||
Google "layerdrop arxiv", as its not explainable in one line.
|
||||
decoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
|
||||
Google "layerdrop arxiv", as its not explainable in one line.
|
||||
extra_pos_embeddings: (:obj:`int`, optional, defaults to 2):
|
||||
How many extra learned positional embeddings to use. Should be pad_token_id+1 for bart.
|
||||
num_labels: (:obj:`int`, optional, defaults to 2):
|
||||
for SequenceClassification
|
||||
is_encoder_decoder (:obj:`int`, optional, defaults to True):
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
|
||||
class BartConfig(PretrainedConfig):
|
||||
r"""
|
||||
Configuration class for Bart. Parameters are renamed from the fairseq implementation
|
||||
@@ -42,7 +108,7 @@ class BartConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
activation_dropout=0.0,
|
||||
extra_pos_embeddings=2,
|
||||
extra_pos_embeddings=2, # FIXME(@sshleifer): delete?
|
||||
activation_function="gelu",
|
||||
vocab_size=50265,
|
||||
d_model=1024,
|
||||
@@ -81,6 +147,7 @@ class BartConfig(PretrainedConfig):
|
||||
|
||||
>>> config = BartConfig.from_pretrained('facebook/bart-large')
|
||||
>>> model = BartModel(config)
|
||||
|
||||
"""
|
||||
if "hidden_size" in common_kwargs:
|
||||
raise ValueError("hidden size is called d_model")
|
||||
@@ -146,3 +213,4 @@ class BartConfig(PretrainedConfig):
|
||||
|
||||
class MBartConfig(BartConfig):
|
||||
model_type = "mbart"
|
||||
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""
|
||||
|
||||
62
src/transformers/configuration_pegasus.py
Normal file
62
src/transformers/configuration_pegasus.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PEGASUS model configuration """
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULTS = dict(
|
||||
vocab_size=96103,
|
||||
max_position_embeddings=512,
|
||||
d_model=1024,
|
||||
encoder_ffn_dim=4096,
|
||||
decoder_ffn_dim=4096,
|
||||
encoder_attention_heads=16,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layers=16,
|
||||
decoder_layers=16,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
activation_dropout=0.1,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
is_encoder_decoder=True,
|
||||
normalize_before=True,
|
||||
scale_embedding=True,
|
||||
normalize_embedding=False,
|
||||
add_final_layer_norm=True,
|
||||
static_position_embeddings=True,
|
||||
num_beams=8,
|
||||
activation_function="relu",
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
|
||||
class PegasusConfig(BartConfig):
|
||||
r"""
|
||||
:class:`~transformers.PegasusConfig` is the configuration class to store the configuration of a
|
||||
`PegasusModel`.
|
||||
"""
|
||||
model_type = "pegasus"
|
||||
# The implementation of the config object is in BartConfig
|
||||
|
||||
@property
|
||||
def default_config_parameters(self):
|
||||
return DEFAULTS
|
||||
167
src/transformers/convert_pegasus_tf_to_pytorch.py
Normal file
167
src/transformers/convert_pegasus_tf_to_pytorch.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from transformers.configuration_pegasus import DEFAULTS
|
||||
|
||||
|
||||
PATTERNS = [
|
||||
# replace left string with right string to get the relevant state_dict key (identical state dict to bart)
|
||||
["memory_attention", "encoder_attn"],
|
||||
["attention", "attn"],
|
||||
["/", "."],
|
||||
[".LayerNorm.gamma", "_layer_norm.weight"],
|
||||
[".LayerNorm.beta", "_layer_norm.bias"],
|
||||
["r.layer_", "r.layers."],
|
||||
["output_proj", "out_proj"],
|
||||
["ffn.dense_1.", "fc2."],
|
||||
["ffn.dense.", "fc1."],
|
||||
["ffn_layer_norm", "final_layer_norm"],
|
||||
["kernel", "weight"],
|
||||
["encoder_layer_norm.", "encoder.layer_norm."],
|
||||
["decoder_layer_norm.", "decoder.layer_norm."],
|
||||
["embeddings.weights", "shared.weight"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
|
||||
for pegasus_name, bart_name in PATTERNS:
|
||||
k = k.replace(pegasus_name, bart_name)
|
||||
return k
|
||||
|
||||
|
||||
# See appendix C of paper for all hyperparams
|
||||
max_gen_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
"large": 256, # @sshleifer chose arbitrarily
|
||||
}
|
||||
max_model_length = {
|
||||
"xsum": 512,
|
||||
"cnn_dailymail": 1024,
|
||||
"newsroom": 512,
|
||||
"wikihow": 512,
|
||||
"multi_news": 1024,
|
||||
"reddit_tifu": 512,
|
||||
"big_patent": 1024,
|
||||
"arxiv": 1024,
|
||||
"pubmed": 1024,
|
||||
"gigaword": 128,
|
||||
"aeslc": 512,
|
||||
"billsum": 1024,
|
||||
"large": 1024,
|
||||
}
|
||||
|
||||
expected_alpha = {
|
||||
"multinews": 0.9,
|
||||
"wikihow": 0.6,
|
||||
"reddit_tifu": 0.6,
|
||||
"big_patent": 0.7,
|
||||
"gigaword": 0.6,
|
||||
"aeslc": 0.6,
|
||||
"billsum": 0.6,
|
||||
} # otherwise 0.8
|
||||
# TODO(SS): one constant
|
||||
|
||||
|
||||
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
||||
cfg_kwargs = DEFAULTS.copy()
|
||||
cfg_kwargs.update(cfg_updates)
|
||||
|
||||
cfg = PegasusConfig(**cfg_updates)
|
||||
bart = PegasusForConditionalGeneration(cfg)
|
||||
sd = bart.model.state_dict()
|
||||
mapping = {}
|
||||
for k, v in tf_weights.items():
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in sd:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
|
||||
if "dense" in k or "proj" in new_k:
|
||||
v = v.T
|
||||
mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
|
||||
assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"
|
||||
# make sure embedding.padding_idx is respected
|
||||
mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
|
||||
mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
|
||||
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
|
||||
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
|
||||
mapping.update(**empty_biases)
|
||||
missing, extra = bart.model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return bart
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["Adafactor", "global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any([pat in name for pat in ignore_name])
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
|
||||
# save tokenizer first
|
||||
dataset = Path(ckpt_path).parent.name
|
||||
desired_max_model_length = max_model_length[dataset]
|
||||
tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
|
||||
assert tok.model_max_length == desired_max_model_length
|
||||
tok.save_pretrained(save_dir)
|
||||
|
||||
# convert model
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
|
||||
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
if args.save_dir is None:
|
||||
args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}"
|
||||
convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
|
||||
@@ -34,6 +34,7 @@ from .configuration_auto import (
|
||||
LongformerConfig,
|
||||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
RobertaConfig,
|
||||
@@ -125,6 +126,7 @@ from .modeling_mobilebert import (
|
||||
MobileBertModel,
|
||||
)
|
||||
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_reformer import (
|
||||
ReformerForMaskedLM,
|
||||
ReformerForQuestionAnswering,
|
||||
@@ -283,6 +285,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
(T5Config, T5ForConditionalGeneration),
|
||||
(PegasusConfig, PegasusForConditionalGeneration),
|
||||
(MarianConfig, MarianMTModel),
|
||||
(BartConfig, BartForConditionalGeneration),
|
||||
(EncoderDecoderConfig, EncoderDecoderModel),
|
||||
|
||||
@@ -19,9 +19,7 @@ from .configuration_marian import MarianConfig
|
||||
from .modeling_bart import BartForConditionalGeneration
|
||||
|
||||
|
||||
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
|
||||
]
|
||||
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
|
||||
|
||||
|
||||
class MarianMTModel(BartForConditionalGeneration):
|
||||
|
||||
46
src/transformers/modeling_pegasus.py
Normal file
46
src/transformers/modeling_pegasus.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Pegasus model, ported from https://github.com/google-research/pegasus"""
|
||||
|
||||
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
|
||||
|
||||
|
||||
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
|
||||
class PegasusForConditionalGeneration(BartForConditionalGeneration):
|
||||
config_class = PegasusConfig
|
||||
r"""
|
||||
Pytorch version of google's pegasus model for summarization.
|
||||
Model API is identical to BartForConditionalGeneration.
|
||||
Available models are listed at `Model List <https://huggingface.co/models?search=pegasus>`__
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import PegasusTokenizer, PegasusForConditionalGeneration
|
||||
>>> from typing import List
|
||||
>>> PGE_ARTICLE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
|
||||
>>> mname = "google/pegasus-xsum"
|
||||
|
||||
>>> model = PegasusForConditionalGeneration.from_pretrained(mname)
|
||||
>>> tok = PegasusTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE]) # don't need tgt_text for inference
|
||||
>>> gen = model.generate(**batch) # for forward pass: model(**batch)
|
||||
>>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
|
||||
>>> assert summary == "California's largest electricity provider has turned off power to tens of thousands of customers."
|
||||
|
||||
"""
|
||||
# All the code is in src/transformers/modeling_bart.py
|
||||
@@ -30,8 +30,11 @@ from .configuration_auto import (
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
MarianConfig,
|
||||
MBartConfig,
|
||||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
RobertaConfig,
|
||||
@@ -41,8 +44,6 @@ from .configuration_auto import (
|
||||
XLMRobertaConfig,
|
||||
XLNetConfig,
|
||||
)
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mobilebert import MobileBertConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bart import BartTokenizer, MBartTokenizer
|
||||
@@ -58,6 +59,7 @@ from .tokenization_longformer import LongformerTokenizer
|
||||
from .tokenization_marian import MarianTokenizer
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
@@ -79,6 +81,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||
(AlbertConfig, (AlbertTokenizer, None)),
|
||||
(CamembertConfig, (CamembertTokenizer, None)),
|
||||
(PegasusConfig, (PegasusTokenizer, None)),
|
||||
(MBartConfig, (MBartTokenizer, None)),
|
||||
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
||||
(MarianConfig, (MarianTokenizer, None)),
|
||||
|
||||
193
src/transformers/tokenization_pegasus.py
Normal file
193
src/transformers/tokenization_pegasus.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers.tokenization_reformer import ReformerTokenizer
|
||||
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
|
||||
|
||||
class PegasusTokenizer(ReformerTokenizer):
|
||||
offset = 103 # entries 2-104 are only used for pretraining
|
||||
vocab_files_names = {"vocab_file": "spiece.model"}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Dont use reserved words added_token_encoder, added_tokens_decoder because of
|
||||
# AssertionError: Non-consecutive added token '1' found. in from_pretrained
|
||||
assert len(self.added_tokens_decoder) == 0
|
||||
self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
|
||||
# entries 2-104 are only used for pretraining and called unk_2, ...unk_104
|
||||
self.encoder.update({i: f"unk_{i}" for i in range(2, self.offset + 2)})
|
||||
self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
if token in self.decoder:
|
||||
return self.decoder[token]
|
||||
elif token in self.added_tokens_decoder:
|
||||
return self.added_tokens_decoder[token]
|
||||
sp_id = self.sp_model.piece_to_id(token)
|
||||
return sp_id + self.offset
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if index in self.encoder:
|
||||
return self.encoder[index]
|
||||
elif index in self.added_tokens_encoder:
|
||||
return self.added_tokens_encoder[index]
|
||||
else:
|
||||
# assert index > self.offset, f"cannot decode ids between 2 and {self.offset}. Got {index}"
|
||||
token = self.sp_model.IdToPiece(index - self.offset)
|
||||
return token
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.sp_model) + self.offset
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def num_special_tokens_to_add(self, pair=False):
|
||||
"""Just EOS"""
|
||||
return 1
|
||||
|
||||
def _special_token_mask(self, seq):
|
||||
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
|
||||
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
|
||||
assert all_special_ids == set([0, 1])
|
||||
return [1 if x in all_special_ids else 0 for x in seq]
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""Get list where entries are [1] if a token is [eos] or [pad] else 0."""
|
||||
if already_has_special_tokens:
|
||||
return self._special_token_mask(token_ids_0)
|
||||
elif token_ids_1 is None:
|
||||
return self._special_token_mask(token_ids_0) + [1]
|
||||
else:
|
||||
return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
|
||||
- single sequence: ``X </s>``
|
||||
- pair of sequences: ``A B </s>`` (not intended use)
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0 + [self.eos_token_id]
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Prepare model inputs for summarization or translation.
|
||||
|
||||
Arguments:
|
||||
src_texts: (:obj:`list`):
|
||||
list of documents to summarize or source language texts
|
||||
tgt_texts: (:obj:`list`, `optional`):
|
||||
list of tgt language texts or summaries.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
|
||||
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
|
||||
length is required by one of the truncation/padding parameters. If the model has no specific maximum
|
||||
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries)
|
||||
If left unset or set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
|
||||
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
|
||||
This does not include causal mask, which is built by the model.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
|
||||
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
||||
|
||||
"""
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
return model_inputs
|
||||
Reference in New Issue
Block a user