T5 & mT5 (#8552)
* add mt5 and t5v1_1 model * fix tests * correct some imports * add tf model * finish tf t5 * improve examples * fix copies * clean doc
This commit is contained in:
committed by
GitHub
parent
9e01f988dd
commit
86822a358b
@@ -248,6 +248,7 @@ conversion utilities for the following models:
|
||||
model_doc/marian
|
||||
model_doc/mbart
|
||||
model_doc/mobilebert
|
||||
model_doc/mt5
|
||||
model_doc/gpt
|
||||
model_doc/gpt2
|
||||
model_doc/pegasus
|
||||
|
||||
53
docs/source/model_doc/mt5.rst
Normal file
53
docs/source/model_doc/mt5.rst
Normal file
@@ -0,0 +1,53 @@
|
||||
MT5
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
|
||||
<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
|
||||
Siddhant, Aditya Barua, Colin Raffel.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
|
||||
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
|
||||
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe
|
||||
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
|
||||
benchmarks. All of the code and model checkpoints used in this work are publicly available.*
|
||||
|
||||
The original code can be found `here <https://github.com/google-research/multilingual-t5>`__.
|
||||
|
||||
MT5Config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MT5Config
|
||||
:members:
|
||||
|
||||
|
||||
MT5Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MT5Model
|
||||
:members:
|
||||
|
||||
|
||||
MT5ForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MT5ForConditionalGeneration
|
||||
:members:
|
||||
|
||||
|
||||
TFMT5Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMT5Model
|
||||
:members:
|
||||
|
||||
|
||||
TFMT5ForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMT5ForConditionalGeneration
|
||||
:members:
|
||||
@@ -498,6 +498,7 @@ if is_torch_available():
|
||||
MobileBertPreTrainedModel,
|
||||
load_tf_weights_in_mobilebert,
|
||||
)
|
||||
from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model
|
||||
from .models.openai import (
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
@@ -791,6 +792,7 @@ if is_tf_available():
|
||||
TFMobileBertModel,
|
||||
TFMobileBertPreTrainedModel,
|
||||
)
|
||||
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||
from .models.openai import (
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFOpenAIGPTDoubleHeadsModel,
|
||||
|
||||
@@ -40,6 +40,7 @@ from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
from ..marian.configuration_marian import MarianConfig
|
||||
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
|
||||
from ..mobilebert.configuration_mobilebert import MobileBertConfig
|
||||
from ..mt5.configuration_mt5 import MT5Config
|
||||
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from ..pegasus.configuration_pegasus import PegasusConfig
|
||||
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
|
||||
@@ -101,6 +102,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("retribert", RetriBertConfig),
|
||||
("mt5", MT5Config),
|
||||
("t5", T5Config),
|
||||
("mobilebert", MobileBertConfig),
|
||||
("distilbert", DistilBertConfig),
|
||||
@@ -178,6 +180,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("rag", "RAG"),
|
||||
("xlm-prophetnet", "XLMProphetNet"),
|
||||
("prophetnet", "ProphetNet"),
|
||||
("mt5", "mT5"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ from ..mobilebert.modeling_mobilebert import (
|
||||
MobileBertForTokenClassification,
|
||||
MobileBertModel,
|
||||
)
|
||||
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
||||
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
|
||||
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
|
||||
@@ -209,6 +210,7 @@ from .configuration_auto import (
|
||||
MarianConfig,
|
||||
MBartConfig,
|
||||
MobileBertConfig,
|
||||
MT5Config,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
ProphetNetConfig,
|
||||
@@ -235,6 +237,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(RetriBertConfig, RetriBertModel),
|
||||
(MT5Config, MT5Model),
|
||||
(T5Config, T5Model),
|
||||
(DistilBertConfig, DistilBertModel),
|
||||
(AlbertConfig, AlbertModel),
|
||||
@@ -376,6 +379,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
(MT5Config, MT5ForConditionalGeneration),
|
||||
(T5Config, T5ForConditionalGeneration),
|
||||
(PegasusConfig, PegasusForConditionalGeneration),
|
||||
(MarianConfig, MarianMTModel),
|
||||
|
||||
@@ -106,6 +106,7 @@ from ..mobilebert.modeling_tf_mobilebert import (
|
||||
TFMobileBertForTokenClassification,
|
||||
TFMobileBertModel,
|
||||
)
|
||||
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||
from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
|
||||
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
||||
from ..roberta.modeling_tf_roberta import (
|
||||
@@ -161,6 +162,7 @@ from .configuration_auto import (
|
||||
MarianConfig,
|
||||
MBartConfig,
|
||||
MobileBertConfig,
|
||||
MT5Config,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
RobertaConfig,
|
||||
@@ -182,6 +184,7 @@ TF_MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(LxmertConfig, TFLxmertModel),
|
||||
(MT5Config, TFMT5Model),
|
||||
(T5Config, TFT5Model),
|
||||
(DistilBertConfig, TFDistilBertModel),
|
||||
(AlbertConfig, TFAlbertModel),
|
||||
@@ -294,6 +297,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
(MT5Config, TFMT5ForConditionalGeneration),
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
(MarianConfig, TFMarianMTModel),
|
||||
(MBartConfig, TFMBartForConditionalGeneration),
|
||||
|
||||
13
src/transformers/models/mt5/__init__.py
Normal file
13
src/transformers/models/mt5/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
from ...file_utils import is_tf_available, is_torch_available
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||
122
src/transformers/models/mt5/configuration_mt5.py
Normal file
122
src/transformers/models/mt5/configuration_mt5.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" mT5 model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a
    :class:`~transformers.TFMT5Model`. It is used to instantiate a mT5 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
    to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Arguments:
        vocab_size (:obj:`int`, `optional`, defaults to 250112):
            Vocabulary size of the mT5 model. Defines the number of different tokens that can be represented by the
            :obj:`input_ids` passed when calling :class:`~transformers.MT5Model` or
            :class:`~transformers.TFMT5Model`.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (:obj:`int`, `optional`, defaults to 64):
            Size of the key, query, value projections per attention head. Note that with the defaults of this class
            :obj:`d_kv` is *not* equal to :obj:`d_model // num_heads` (64 vs. ``512 // 6``).
        d_ff (:obj:`int`, `optional`, defaults to 1024):
            Size of the intermediate feed forward layer in each :obj:`T5Block`.
        num_layers (:obj:`int`, `optional`, defaults to 8):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
            set.
        num_heads (:obj:`int`, `optional`, defaults to 6):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
            The number of buckets to use for each attention layer.
        dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
            Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        tokenizer_class (:obj:`str`, `optional`, defaults to :obj:`"T5Tokenizer"`):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        pad_token_id (:obj:`int`, `optional`, defaults to 0):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        eos_token_id (:obj:`int`, `optional`, defaults to 1):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        decoder_start_token_id (:obj:`int`, `optional`, defaults to 0):
            Forwarded unchanged to :class:`~transformers.PretrainedConfig`.
        kwargs:
            Any remaining keyword arguments are forwarded to :class:`~transformers.PretrainedConfig`.
    """
    model_type = "mt5"

    def __init__(
        self,
        vocab_size=250112,
        d_model=512,
        d_kv=64,
        d_ff=1024,
        num_layers=8,
        num_decoder_layers=None,
        num_heads=6,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="gated-gelu",
        is_encoder_decoder=True,
        tokenizer_class="T5Tokenizer",
        tie_word_embeddings=False,
        pad_token_id=0,
        eos_token_id=1,
        decoder_start_token_id=0,
        **kwargs
    ):
        # Generic (non mT5-specific) options are handled by the base configuration class.
        super().__init__(
            is_encoder_decoder=is_encoder_decoder,
            tokenizer_class=tokenizer_class,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj

    @property
    def hidden_size(self):
        """Alias for :obj:`d_model`."""
        return self.d_model

    @property
    def num_attention_heads(self):
        """Alias for :obj:`num_heads`."""
        return self.num_heads

    @property
    def num_hidden_layers(self):
        """Alias for :obj:`num_layers`."""
        return self.num_layers
|
||||
83
src/transformers/models/mt5/modeling_mt5.py
Normal file
83
src/transformers/models/mt5/modeling_mt5.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch mT5 model. """
|
||||
|
||||
from ...utils import logging
|
||||
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "T5Config"
|
||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||
|
||||
|
||||
class MT5Model(T5Model):
    r"""
    mT5 flavour of :class:`~transformers.T5Model`. The class body only rebinds the model type, the configuration
    class and two lists of parameter-name patterns; everything else is inherited, so refer to the superclass for the
    full documentation and further usage examples.

    Examples::
        >>> from transformers import MT5Model, T5Tokenizer
        >>> model = MT5Model.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
        >>> hidden_states = outputs.last_hidden_state
    """
    config_class = MT5Config
    model_type = "mt5"

    # Regex patterns matching parameter names.
    # NOTE(review): presumably consumed by the inherited save/load logic -- confirm against the base class.
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    # Same two embedding patterns plus the decoder cross-attention relative bias (a new list, not an alias).
    authorized_missing_keys = keys_to_never_save + [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
|
||||
|
||||
|
||||
class MT5ForConditionalGeneration(T5ForConditionalGeneration):
    r"""
    mT5 flavour of :class:`~transformers.T5ForConditionalGeneration`. The class body only rebinds the model type, the
    configuration class and two lists of parameter-name patterns; everything else is inherited, so refer to the
    superclass for the full documentation and further usage examples.

    Examples::
        >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
        >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(**batch)
        >>> loss = outputs.loss
    """

    config_class = MT5Config
    model_type = "mt5"

    # Regex patterns matching parameter names.
    # NOTE(review): presumably consumed by the inherited save/load logic -- confirm against the base class.
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    # Same two embedding patterns, then the LM head and the decoder cross-attention relative bias
    # (a new list, not an alias of ``keys_to_never_save``).
    authorized_missing_keys = keys_to_never_save + [
        r"lm_head\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
|
||||
66
src/transformers/models/mt5/modeling_tf_mt5.py
Normal file
66
src/transformers/models/mt5/modeling_tf_mt5.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tensorflow mT5 model. """
|
||||
|
||||
from ...utils import logging
|
||||
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "T5Config"
|
||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||
|
||||
|
||||
class TFMT5Model(TFT5Model):
    r"""
    mT5 flavour of :class:`~transformers.TFT5Model`. The class body only rebinds the model type and the configuration
    class; everything else is inherited, so refer to the superclass for the full documentation and further usage
    examples.

    Examples::
        >>> from transformers import TFMT5Model, T5Tokenizer
        >>> model = TFMT5Model.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
        >>> batch["decoder_input_ids"] = batch["labels"]
        >>> del batch["labels"]
        >>> outputs = model(batch)
        >>> hidden_states = outputs.last_hidden_state
    """
    # Configuration class paired with this model, and its model-type identifier.
    config_class = MT5Config
    model_type = "mt5"
|
||||
|
||||
|
||||
class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
    r"""
    mT5 flavour of :class:`~transformers.TFT5ForConditionalGeneration`. The class body only rebinds the model type
    and the configuration class; everything else is inherited, so refer to the superclass for the full documentation
    and further usage examples.

    Examples::
        >>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
        >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
        >>> outputs = model(batch)
        >>> loss = outputs.loss
    """

    # Configuration class paired with this model, and its model-type identifier.
    config_class = MT5Config
    model_type = "mt5"
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2010, The T5 Authors and HuggingFace Inc.
|
||||
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -43,9 +43,6 @@ class T5Config(PretrainedConfig):
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 32128):
|
||||
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`.
|
||||
n_positions (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
d_model (:obj:`int`, `optional`, defaults to 512):
|
||||
Size of the encoder layers and the pooler layer.
|
||||
d_kv (:obj:`int`, `optional`, defaults to 64):
|
||||
@@ -69,6 +66,9 @@ class T5Config(PretrainedConfig):
|
||||
initializer_factor (:obj:`float`, `optional`, defaults to 1):
|
||||
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
||||
testing).
|
||||
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
|
||||
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
|
||||
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
|
||||
"""
|
||||
model_type = "t5"
|
||||
|
||||
@@ -85,6 +85,7 @@ class T5Config(PretrainedConfig):
|
||||
dropout_rate=0.1,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_factor=1.0,
|
||||
feed_forward_proj="relu",
|
||||
is_encoder_decoder=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
@@ -109,6 +110,7 @@ class T5Config(PretrainedConfig):
|
||||
self.dropout_rate = dropout_rate
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = T5Model(config)
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||
|
||||
@@ -25,6 +25,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
DUMMY_MASK,
|
||||
@@ -140,6 +141,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
||||
continue
|
||||
elif scope_names[0] == "logits":
|
||||
pointer = getattr(pointer, "lm_head")
|
||||
elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
|
||||
pointer = getattr(pointer, f"wi_{scope_names[1]}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
@@ -211,10 +215,36 @@ class T5DenseReluDense(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedGeluDense(nn.Module):
    """
    Gated feed-forward projection: ``wo(dropout(gelu_new(wi_0(x)) * wi_1(x)))``.

    All three linear layers are bias-free. ``wi_0`` and ``wi_1`` map ``config.d_model`` to ``config.d_ff`` and ``wo``
    maps back from ``config.d_ff`` to ``config.d_model``.
    """

    def __init__(self, config):
        super().__init__()
        model_dim, ff_dim = config.d_model, config.d_ff
        # Sub-modules are created in the same order as before so parameter registration/initialisation order is kept.
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        """Apply the gated projection to ``hidden_states`` and return the result."""
        # The activated branch (wi_0) gates the purely linear branch (wi_1) element-wise.
        gated = self.gelu_act(self.wi_0(hidden_states)) * self.wi_1(hidden_states)
        return self.wo(self.dropout(gated))
|
||||
|
||||
|
||||
class T5LayerFF(nn.Module):
    def __init__(self, config):
        """
        Build the feed-forward sub-layer selected by ``config.feed_forward_proj``.

        Args:
            config: configuration object; this constructor reads ``feed_forward_proj``, ``d_model``,
                ``layer_norm_epsilon`` and ``dropout_rate`` from it.

        Raises:
            ValueError: if ``config.feed_forward_proj`` is neither ``"relu"`` nor ``"gated-gelu"``.
        """
        super().__init__()
        # Both variants are stored under the same attribute name ``DenseReluDense``.
        # NOTE(review): presumably so that parameter names stay identical across variants -- confirm.
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            # Read the value from the ``config`` argument: ``self.config`` is never assigned on this module, so
            # formatting the message through it raised AttributeError instead of the intended ValueError.
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
@@ -641,6 +671,16 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
||||
module.wo.bias.data.zero_()
|
||||
elif isinstance(module, T5DenseGatedGeluDense):
|
||||
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
|
||||
module.wi_0.bias.data.zero_()
|
||||
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
|
||||
module.wi_1.bias.data.zero_()
|
||||
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
||||
module.wo.bias.data.zero_()
|
||||
elif isinstance(module, T5Attention):
|
||||
# Mesh TensorFlow attention initialization to avoid scaling before softmax
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
|
||||
@@ -1099,8 +1139,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
r"decoder\.embed_tokens\.weight",
|
||||
r"lm_head\.weight",
|
||||
r"encoder\.embed_tokens\.weight",
|
||||
r"decoder\.embed_tokens\.weight",
|
||||
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
||||
]
|
||||
|
||||
@@ -1262,9 +1300,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
)
|
||||
|
||||
sequence_output = decoder_outputs[0]
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
# Rescale output before projecting on vocab
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
||||
sequence_output = sequence_output * (self.model_dim ** -0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
|
||||
loss = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
|
||||
# Copyright 2020 T5 Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -26,6 +26,7 @@ import tensorflow as tf
|
||||
|
||||
from transformers.modeling_tf_utils import TFWrappedEmbeddings
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
DUMMY_MASK,
|
||||
@@ -103,10 +104,35 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFT5GatedGeluDense(tf.keras.layers.Layer):
    """
    Gated feed-forward projection: ``wo(dropout(gelu_new(wi_0(x)) * wi_1(x)))``.

    All three dense layers are bias-free. ``wi_0`` and ``wi_1`` have ``config.d_ff`` units and ``wo`` has
    ``config.d_model`` units.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        ff_dim = config.d_ff
        self.wi_0 = tf.keras.layers.Dense(ff_dim, use_bias=False, name="wi_0")
        self.wi_1 = tf.keras.layers.Dense(ff_dim, use_bias=False, name="wi_1")
        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.act = get_tf_activation("gelu_new")

    def call(self, hidden_states, training=False):
        """Apply the gated projection; ``training`` is passed through to the dropout layer."""
        # The activated branch (wi_0) gates the purely linear branch (wi_1) element-wise.
        gated = self.act(self.wi_0(hidden_states)) * self.wi_1(hidden_states)
        gated = self.dropout(gated, training=training)
        return self.wo(gated)
|
||||
|
||||
|
||||
class TFT5LayerFF(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        """
        Build the feed-forward sub-layer selected by ``config.feed_forward_proj``.

        Args:
            config: configuration object; this constructor reads ``feed_forward_proj``, ``layer_norm_epsilon`` and
                ``dropout_rate`` from it.
            **kwargs: forwarded to the Keras layer base-class constructor.

        Raises:
            ValueError: if ``config.feed_forward_proj`` is neither ``"relu"`` nor ``"gated-gelu"``.
        """
        super().__init__(**kwargs)
        # Both variants are created with the same layer name ``"DenseReluDense"``.
        # NOTE(review): presumably so that weight names stay identical across variants -- confirm.
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
        else:
            # Read the value from the ``config`` argument: ``self.config`` is never assigned in this constructor,
            # so formatting the message through it would fail before the intended ValueError could be raised.
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
|
||||
@@ -547,9 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -970,9 +993,6 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
@@ -1165,11 +1185,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
decoder_config.is_decoder = True
|
||||
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
return self.shared
|
||||
else:
|
||||
return self.lm_head
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
@@ -1331,9 +1357,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
|
||||
embed_tokens = self.get_output_embeddings()
|
||||
logits = embed_tokens(sequence_output, mode="linear")
|
||||
sequence_output = decoder_outputs[0]
|
||||
|
||||
# T5v1.1 does not tie output word embeddings and thus does not require downscaling
|
||||
if self.config.tie_word_embeddings:
|
||||
sequence_output = sequence_output * (self.model_dim ** -0.5)
|
||||
logits = self.get_output_embeddings()(sequence_output, mode="linear")
|
||||
else:
|
||||
logits = self.get_output_embeddings()(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
|
||||
@@ -1361,6 +1361,29 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_mobilebert)
|
||||
|
||||
|
||||
class MT5Config:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MT5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MT5Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
@@ -970,6 +970,24 @@ class TFMobileBertPreTrainedModel:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMT5ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMT5Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
39
tests/test_modeling_mt5.py
Normal file
39
tests/test_modeling_mt5.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class MT5IntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_small_integration_test(self):
|
||||
"""
|
||||
For comparison run:
|
||||
>>> import t5 # pip install t5==0.7.1
|
||||
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||
|
||||
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
|
||||
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
|
||||
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
|
||||
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path)
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small", return_dict=True).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
|
||||
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
|
||||
|
||||
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
|
||||
mtf_score = -(labels.shape[-1] * loss.item())
|
||||
|
||||
EXPECTED_SCORE = -84.9127
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
@@ -490,6 +490,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_v1_1(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
# check that gated gelu feed forward and different word embeddings work
|
||||
config = config_and_inputs[0]
|
||||
config.tie_word_embeddings = False
|
||||
config.feed_forward_proj = "gated-gelu"
|
||||
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
|
||||
|
||||
def test_with_lm_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
|
||||
@@ -569,7 +577,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device)
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
|
||||
@@ -581,6 +589,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_SCORE = -19.0845
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
|
||||
@slow
|
||||
def test_small_v1_1_integration_test(self):
|
||||
"""
|
||||
For comparison run:
|
||||
>>> import t5 # pip install t5==0.7.1
|
||||
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||
|
||||
>>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
|
||||
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
|
||||
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
|
||||
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small").to(torch_device)
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
|
||||
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
|
||||
|
||||
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
|
||||
mtf_score = -(labels.shape[-1] * loss.item())
|
||||
|
||||
EXPECTED_SCORE = -59.0293
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
|
||||
@slow
|
||||
def test_summarization(self):
|
||||
model = self.model
|
||||
|
||||
56
tests/test_modeling_tf_mt5.py
Normal file
56
tests/test_modeling_tf_mt5.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TFMT5ModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_small_integration_test(self):
|
||||
"""
|
||||
For comparison run:
|
||||
>>> import t5 # pip install t5==0.7.1
|
||||
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||
|
||||
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
|
||||
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
|
||||
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
|
||||
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path, extra_ids=100)
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = TFAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
|
||||
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
|
||||
|
||||
loss = model(input_ids, labels=labels).loss
|
||||
mtf_score = -tf.math.reduce_sum(loss).numpy()
|
||||
|
||||
EXPECTED_SCORE = -84.9127
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
@@ -258,6 +258,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_model(*config_and_inputs)
|
||||
|
||||
def test_t5_model_v1_1(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config = config_and_inputs[0]
|
||||
config.tie_word_embeddings = False
|
||||
config.feed_forward_proj = "gated-gelu"
|
||||
self.model_tester.create_and_check_t5_model(config, *config_and_inputs[1:])
|
||||
|
||||
def test_with_lm_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
|
||||
@@ -296,6 +303,58 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
|
||||
def model(self):
|
||||
return TFT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||
|
||||
@slow
|
||||
def test_small_integration_test(self):
|
||||
"""
|
||||
For comparison run:
|
||||
>>> import t5 # pip install t5==0.7.1
|
||||
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||
|
||||
>>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
|
||||
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
|
||||
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
|
||||
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
|
||||
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
|
||||
|
||||
loss = model(input_ids, labels=labels).loss
|
||||
mtf_score = -tf.math.reduce_sum(loss).numpy()
|
||||
|
||||
EXPECTED_SCORE = -19.0845
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
|
||||
@slow
|
||||
def test_small_v1_1_integration_test(self):
|
||||
"""
|
||||
For comparison run:
|
||||
>>> import t5 # pip install t5==0.7.1
|
||||
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||
|
||||
>>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
|
||||
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
|
||||
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
|
||||
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
|
||||
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||
"""
|
||||
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
|
||||
|
||||
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
|
||||
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
|
||||
|
||||
loss = model(input_ids, labels=labels).loss
|
||||
mtf_score = -tf.math.reduce_sum(loss).numpy()
|
||||
|
||||
EXPECTED_SCORE = -59.0293
|
||||
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||
|
||||
@slow
|
||||
def test_summarization(self):
|
||||
model = self.model
|
||||
|
||||
@@ -46,6 +46,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"test_modeling_flax_bert.py",
|
||||
"test_modeling_flax_roberta.py",
|
||||
"test_modeling_mbart.py",
|
||||
"test_modeling_mt5.py",
|
||||
"test_modeling_pegasus.py",
|
||||
"test_modeling_tf_camembert.py",
|
||||
"test_modeling_tf_xlm_roberta.py",
|
||||
|
||||
Reference in New Issue
Block a user