Add PLBart (#13269)
* Init PLBART * Add missing configuration file * Add conversion script and configurationf ile * Fix style * Update modeling and conversion scripts * Fix scale embedding in config * Add comment * Fix conversion script * Add classification option to conversion script * Fix vocab size in config doc * Add tokenizer files from MBart50 * Allow no lang code in regular tokenizer * Add PLBart Tokenizer Converters * Remove mask from multi tokenizer * Remove mask from multi tokenizer * Change from MBart-50 to MBart tokenizer * Fix names and modify src/tgt behavior * Fix imports for tokenizer * Remove <mask> from multi tokenizer * Fix style * Change tokenizer_class to processor_class * Add attribute map to config class * Update modeling file to modified MBart code * Update configuration file to MBart style configuration * Fix tokenizer * Separate tokenizers * Fix error in tokenization auto * Copy MBart tests * Replace with MBart tokenization tests * Fix style * Fix language code in multi tokenizer * Fix configuration docs * Add entry for plbart_multi in transformers init * Add dummy objects and fix imports * Fix modeling tests * Add TODO in config * Fix copyright year * Fix modeling docs and test * Fix some tokenization tests and style * Add changes from review * Fix copies * Fix docs * Fix docs * Fix style * Fix year * Add changes from review * Remove extra changes * Fix base tokenizer and doc * Fix style * Fix modeling and slow tokenizer tests * Remove Multi-tokenizer Converter and Tests * Delete QA model and Multi Tokenizer dummy objects * Fix repo consistency and code quality issues * Fix example documentation * Fix style * Remove PLBartTokenizer from type checking in init * Fix consistency issue * Add changes from review * Fix style * Remove PLBartTokenizerFast * Remove FastTokenizer converter * Fix AutoTokenzier mapping * Add plbart to toctree and fix consistency issues * Add language codes tokenizer test * Fix styling and doc issues * Add fixes for failing tests * Fix copies * Fix failing modeling test * Change assert to assertTrue in modeling tests
This commit is contained in:
@@ -246,6 +246,8 @@
|
|||||||
title: Pegasus
|
title: Pegasus
|
||||||
- local: model_doc/phobert
|
- local: model_doc/phobert
|
||||||
title: PhoBERT
|
title: PhoBERT
|
||||||
|
- local: model_doc/plbart
|
||||||
|
title: PLBart
|
||||||
- local: model_doc/poolformer
|
- local: model_doc/poolformer
|
||||||
title: PoolFormer
|
title: PoolFormer
|
||||||
- local: model_doc/prophetnet
|
- local: model_doc/prophetnet
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
| PLBart | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
|||||||
112
docs/source/model_doc/plbart.mdx
Normal file
112
docs/source/model_doc/plbart.mdx
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# PLBart
|
||||||
|
|
||||||
|
**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title) and assign
|
||||||
|
[@gchhablani](https://www.github.com/gchhablani).
|
||||||
|
|
||||||
|
## Overview of PLBart
|
||||||
|
|
||||||
|
The PLBART model was proposed in [Unified Pre-training for Program Understanding and Generation](https://arxiv.org/abs/2103.06333) by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, Kai-Wei Chang.
|
||||||
|
This is a BART-like model which can be used to perform code-summarization, code-generation, and code-translation tasks. The pre-trained model `plbart-base` has been trained using multilingual denoising task
|
||||||
|
on Java, Python and English.
|
||||||
|
|
||||||
|
According to the abstract
|
||||||
|
|
||||||
|
*Code summarization and generation empower conversion between programming language (PL) and natural language (NL),
|
||||||
|
while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
|
||||||
|
a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks.
|
||||||
|
PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding.
|
||||||
|
Experiments on code summarization in the English language, code generation, and code translation in seven programming languages
|
||||||
|
show that PLBART outperforms or rivals state-of-the-art models. Moreover, experiments on discriminative tasks, e.g., program
|
||||||
|
repair, clone detection, and vulnerable code detection, demonstrate PLBART's effectiveness in program understanding.
|
||||||
|
Furthermore, analysis reveals that PLBART learns program syntax, style (e.g., identifier naming convention), logical flow
|
||||||
|
(e.g., if block inside an else block is equivalent to else if block) that are crucial to program semantics and thus excels
|
||||||
|
even with limited annotations.*
|
||||||
|
|
||||||
|
This model was contributed by [gchhablani](https://huggingface.co/gchhablani). The Authors' code can be found [here](https://github.com/wasiahmad/PLBART).
|
||||||
|
|
||||||
|
### Training of PLBart
|
||||||
|
|
||||||
|
PLBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for code-to-text, text-to-code, code-to-code tasks. As the
|
||||||
|
model is multilingual it expects the sequences in a different format. A special language id token is added in both the
|
||||||
|
source and target text. The source text format is `X [eos, src_lang_code]` where `X` is the source text. The
|
||||||
|
target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
|
||||||
|
|
||||||
|
However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
|
||||||
|
|
||||||
|
In cases where the language code is needed, The regular [`~PLBartTokenizer.__call__`] will encode source text format, and it should be wrapped
|
||||||
|
inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode target text format.
|
||||||
|
|
||||||
|
- Supervised training
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer
|
||||||
|
|
||||||
|
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="en_XX", tgt_lang="python")
|
||||||
|
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
|
||||||
|
>>> expected_translation_english = "Returns the maximum value of a b c."
|
||||||
|
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... labels = tokenizer(expected_translation_english, return_tensors="pt")
|
||||||
|
>>> inputs["labels"] = labels["input_ids"]
|
||||||
|
>>> # forward pass
|
||||||
|
>>> model(**inputs)
|
||||||
|
```
|
||||||
|
|
||||||
|
- Generation
|
||||||
|
|
||||||
|
While generating the target text set the `decoder_start_token_id` to the target language id. The following
|
||||||
|
example shows how to translate Python to English using the `uclanlp/plbart-python-en_XX` model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer
|
||||||
|
|
||||||
|
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
|
||||||
|
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
|
||||||
|
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
|
||||||
|
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-python-en_XX")
|
||||||
|
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"])
|
||||||
|
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||||
|
"Returns the maximum value of a b c."
|
||||||
|
```
|
||||||
|
|
||||||
|
## PLBartConfig
|
||||||
|
|
||||||
|
[[autodoc]] PLBartConfig
|
||||||
|
|
||||||
|
## PLBartTokenizer
|
||||||
|
|
||||||
|
[[autodoc]] PLBartTokenizer
|
||||||
|
- as_target_tokenizer
|
||||||
|
- build_inputs_with_special_tokens
|
||||||
|
|
||||||
|
## PLBartModel
|
||||||
|
|
||||||
|
[[autodoc]] PLBartModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## PLBartForConditionalGeneration
|
||||||
|
|
||||||
|
[[autodoc]] PLBartForConditionalGeneration
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## PLBartForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] PLBartForSequenceClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## PLBartForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] PLBartForCausalLM
|
||||||
|
- forward
|
||||||
@@ -57,6 +57,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- Marian
|
- Marian
|
||||||
- mBART
|
- mBART
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
|
- PLBart
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- T5
|
- T5
|
||||||
- XLM-RoBERTa
|
- XLM-RoBERTa
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ _import_structure = {
|
|||||||
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
|
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
|
||||||
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
|
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
|
||||||
"models.phobert": ["PhobertTokenizer"],
|
"models.phobert": ["PhobertTokenizer"],
|
||||||
|
"models.plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
|
||||||
"models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
|
"models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
|
||||||
"models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"],
|
"models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"],
|
||||||
"models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
|
"models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
|
||||||
@@ -410,6 +411,7 @@ if is_sentencepiece_available():
|
|||||||
_import_structure["models.mluke"].append("MLukeTokenizer")
|
_import_structure["models.mluke"].append("MLukeTokenizer")
|
||||||
_import_structure["models.mt5"].append("MT5Tokenizer")
|
_import_structure["models.mt5"].append("MT5Tokenizer")
|
||||||
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
||||||
|
_import_structure["models.plbart"].append("PLBartTokenizer")
|
||||||
_import_structure["models.reformer"].append("ReformerTokenizer")
|
_import_structure["models.reformer"].append("ReformerTokenizer")
|
||||||
_import_structure["models.rembert"].append("RemBertTokenizer")
|
_import_structure["models.rembert"].append("RemBertTokenizer")
|
||||||
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
|
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
|
||||||
@@ -1219,6 +1221,16 @@ if is_torch_available():
|
|||||||
"PerceiverPreTrainedModel",
|
"PerceiverPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.plbart"].extend(
|
||||||
|
[
|
||||||
|
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"PLBartForCausalLM",
|
||||||
|
"PLBartForConditionalGeneration",
|
||||||
|
"PLBartForSequenceClassification",
|
||||||
|
"PLBartModel",
|
||||||
|
"PLBartPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.poolformer"].extend(
|
_import_structure["models.poolformer"].extend(
|
||||||
[
|
[
|
||||||
"POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -2498,6 +2510,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
|
from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
|
||||||
from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
|
from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
|
||||||
from .models.phobert import PhobertTokenizer
|
from .models.phobert import PhobertTokenizer
|
||||||
|
from .models.plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
|
||||||
from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
|
from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
|
||||||
from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer
|
from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer
|
||||||
from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
|
from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
|
||||||
@@ -2630,6 +2643,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.mluke import MLukeTokenizer
|
from .models.mluke import MLukeTokenizer
|
||||||
from .models.mt5 import MT5Tokenizer
|
from .models.mt5 import MT5Tokenizer
|
||||||
from .models.pegasus import PegasusTokenizer
|
from .models.pegasus import PegasusTokenizer
|
||||||
|
from .models.plbart import PLBartTokenizer
|
||||||
from .models.reformer import ReformerTokenizer
|
from .models.reformer import ReformerTokenizer
|
||||||
from .models.rembert import RemBertTokenizer
|
from .models.rembert import RemBertTokenizer
|
||||||
from .models.speech_to_text import Speech2TextTokenizer
|
from .models.speech_to_text import Speech2TextTokenizer
|
||||||
@@ -3292,6 +3306,14 @@ if TYPE_CHECKING:
|
|||||||
PerceiverModel,
|
PerceiverModel,
|
||||||
PerceiverPreTrainedModel,
|
PerceiverPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.plbart import (
|
||||||
|
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
PLBartForCausalLM,
|
||||||
|
PLBartForConditionalGeneration,
|
||||||
|
PLBartForSequenceClassification,
|
||||||
|
PLBartModel,
|
||||||
|
PLBartPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.poolformer import (
|
from .models.poolformer import (
|
||||||
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
PoolFormerForImageClassification,
|
PoolFormerForImageClassification,
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ from . import (
|
|||||||
pegasus,
|
pegasus,
|
||||||
perceiver,
|
perceiver,
|
||||||
phobert,
|
phobert,
|
||||||
|
plbart,
|
||||||
poolformer,
|
poolformer,
|
||||||
prophetnet,
|
prophetnet,
|
||||||
qdqbert,
|
qdqbert,
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("perceiver", "PerceiverConfig"),
|
("perceiver", "PerceiverConfig"),
|
||||||
("gptj", "GPTJConfig"),
|
("gptj", "GPTJConfig"),
|
||||||
("layoutlmv2", "LayoutLMv2Config"),
|
("layoutlmv2", "LayoutLMv2Config"),
|
||||||
|
("plbart", "PLBartConfig"),
|
||||||
("beit", "BeitConfig"),
|
("beit", "BeitConfig"),
|
||||||
("rembert", "RemBertConfig"),
|
("rembert", "RemBertConfig"),
|
||||||
("visual_bert", "VisualBertConfig"),
|
("visual_bert", "VisualBertConfig"),
|
||||||
@@ -143,6 +144,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
|||||||
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
|
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
@@ -228,6 +230,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("perceiver", "Perceiver"),
|
("perceiver", "Perceiver"),
|
||||||
("gptj", "GPT-J"),
|
("gptj", "GPT-J"),
|
||||||
("beit", "BEiT"),
|
("beit", "BEiT"),
|
||||||
|
("plbart", "PLBart"),
|
||||||
("rembert", "RemBERT"),
|
("rembert", "RemBERT"),
|
||||||
("layoutlmv2", "LayoutLMv2"),
|
("layoutlmv2", "LayoutLMv2"),
|
||||||
("visual_bert", "VisualBert"),
|
("visual_bert", "VisualBert"),
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("perceiver", "PerceiverModel"),
|
("perceiver", "PerceiverModel"),
|
||||||
("gptj", "GPTJModel"),
|
("gptj", "GPTJModel"),
|
||||||
("layoutlmv2", "LayoutLMv2Model"),
|
("layoutlmv2", "LayoutLMv2Model"),
|
||||||
|
("plbart", "PLBartModel"),
|
||||||
("beit", "BeitModel"),
|
("beit", "BeitModel"),
|
||||||
("rembert", "RemBertModel"),
|
("rembert", "RemBertModel"),
|
||||||
("visual_bert", "VisualBertModel"),
|
("visual_bert", "VisualBertModel"),
|
||||||
@@ -163,6 +164,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model with LM heads mapping
|
# Model with LM heads mapping
|
||||||
("yoso", "YosoForMaskedLM"),
|
("yoso", "YosoForMaskedLM"),
|
||||||
("nystromformer", "NystromformerForMaskedLM"),
|
("nystromformer", "NystromformerForMaskedLM"),
|
||||||
|
("plbart", "PLBartForConditionalGeneration"),
|
||||||
("qdqbert", "QDQBertForMaskedLM"),
|
("qdqbert", "QDQBertForMaskedLM"),
|
||||||
("fnet", "FNetForMaskedLM"),
|
("fnet", "FNetForMaskedLM"),
|
||||||
("gptj", "GPTJForCausalLM"),
|
("gptj", "GPTJForCausalLM"),
|
||||||
@@ -216,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
[
|
[
|
||||||
# Model for Causal LM mapping
|
# Model for Causal LM mapping
|
||||||
("xglm", "XGLMForCausalLM"),
|
("xglm", "XGLMForCausalLM"),
|
||||||
|
("plbart", "PLBartForCausalLM"),
|
||||||
("qdqbert", "QDQBertLMHeadModel"),
|
("qdqbert", "QDQBertLMHeadModel"),
|
||||||
("trocr", "TrOCRForCausalLM"),
|
("trocr", "TrOCRForCausalLM"),
|
||||||
("gptj", "GPTJForCausalLM"),
|
("gptj", "GPTJForCausalLM"),
|
||||||
@@ -361,6 +364,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Seq2Seq Causal LM mapping
|
# Model for Seq2Seq Causal LM mapping
|
||||||
|
("plbart", "PLBartForConditionalGeneration"),
|
||||||
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
|
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
|
||||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||||
("led", "LEDForConditionalGeneration"),
|
("led", "LEDForConditionalGeneration"),
|
||||||
@@ -391,6 +395,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model for Sequence Classification mapping
|
# Model for Sequence Classification mapping
|
||||||
("yoso", "YosoForSequenceClassification"),
|
("yoso", "YosoForSequenceClassification"),
|
||||||
("nystromformer", "NystromformerForSequenceClassification"),
|
("nystromformer", "NystromformerForSequenceClassification"),
|
||||||
|
("plbart", "PLBartForSequenceClassification"),
|
||||||
("perceiver", "PerceiverForSequenceClassification"),
|
("perceiver", "PerceiverForSequenceClassification"),
|
||||||
("qdqbert", "QDQBertForSequenceClassification"),
|
("qdqbert", "QDQBertForSequenceClassification"),
|
||||||
("fnet", "FNetForSequenceClassification"),
|
("fnet", "FNetForSequenceClassification"),
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
|
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
|
|||||||
61
src/transformers/models/plbart/__init__.py
Normal file
61
src/transformers/models/plbart/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
_import_structure = {
|
||||||
|
"configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_sentencepiece_available():
|
||||||
|
_import_structure["tokenization_plbart"] = ["PLBartTokenizer"]
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
_import_structure["modeling_plbart"] = [
|
||||||
|
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"PLBartForCausalLM",
|
||||||
|
"PLBartForConditionalGeneration",
|
||||||
|
"PLBartForSequenceClassification",
|
||||||
|
"PLBartModel",
|
||||||
|
"PLBartPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
|
||||||
|
|
||||||
|
if is_sentencepiece_available():
|
||||||
|
from .tokenization_plbart import PLBartTokenizer
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from .modeling_plbart import (
|
||||||
|
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
PLBartForCausalLM,
|
||||||
|
PLBartForConditionalGeneration,
|
||||||
|
PLBartForSequenceClassification,
|
||||||
|
PLBartModel,
|
||||||
|
PLBartPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||||
192
src/transformers/models/plbart/configuration_plbart.py
Normal file
192
src/transformers/models/plbart/configuration_plbart.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PLBART model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfigWithPast
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
"uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/config.json",
|
||||||
|
# See all PLBART models at https://huggingface.co/models?filter=plbart
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate an
|
||||||
|
PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of the PLBART
|
||||||
|
[uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 50005):
|
||||||
|
Vocabulary size of the PLBART model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`PLBartModel`].
|
||||||
|
d_model (`int`, *optional*, defaults to 768):
|
||||||
|
Dimensionality of the layers and the pooler layer.
|
||||||
|
encoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
Number of encoder layers.
|
||||||
|
decoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
Number of decoder layers.
|
||||||
|
encoder_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
decoder_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
decoder_ffn_dim (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
|
encoder_ffn_dim (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
|
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for activations inside the fully connected layer.
|
||||||
|
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for classifier.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||||
|
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||||
|
just in case (e.g., 512 or 1024 or 2048).
|
||||||
|
init_std (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
|
||||||
|
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||||
|
for more details.
|
||||||
|
decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
|
||||||
|
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||||
|
for more details.
|
||||||
|
scale_embedding (`bool`, *optional*, defaults to `True`):
|
||||||
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models)
|
||||||
|
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||||
|
`eos_token_id`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import PLBartModel, PLBartConfig
|
||||||
|
|
||||||
|
>>> # Initializing a PLBART uclanlp/plbart-base style configuration
|
||||||
|
>>> configuration = PLBartConfig()
|
||||||
|
>>> # Initializing a model from the uclanlp/plbart-base style configuration
|
||||||
|
>>> model = PLBartModel(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "plbart"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50005,
|
||||||
|
max_position_embeddings=1024,
|
||||||
|
encoder_layers=6,
|
||||||
|
encoder_ffn_dim=3072,
|
||||||
|
encoder_attention_heads=12,
|
||||||
|
decoder_layers=6,
|
||||||
|
decoder_ffn_dim=3072,
|
||||||
|
decoder_attention_heads=12,
|
||||||
|
encoder_layerdrop=0.0,
|
||||||
|
decoder_layerdrop=0.0,
|
||||||
|
use_cache=True,
|
||||||
|
is_encoder_decoder=True,
|
||||||
|
activation_function="gelu",
|
||||||
|
d_model=768,
|
||||||
|
dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
activation_dropout=0.0,
|
||||||
|
init_std=0.02,
|
||||||
|
classifier_dropout=0.0,
|
||||||
|
scale_embedding=True,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.d_model = d_model
|
||||||
|
self.encoder_ffn_dim = encoder_ffn_dim
|
||||||
|
self.encoder_layers = encoder_layers
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
self.activation_function = activation_function
|
||||||
|
self.init_std = init_std
|
||||||
|
self.encoder_layerdrop = encoder_layerdrop
|
||||||
|
self.decoder_layerdrop = decoder_layerdrop
|
||||||
|
self.classifier_dropout = classifier_dropout
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.num_hidden_layers = encoder_layers
|
||||||
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartOnnxConfig(OnnxConfigWithPast):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.use_past:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||||
|
("past_keys", {0: "batch", 2: "sequence"}),
|
||||||
|
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||||
|
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers import PLBartConfig, PLBartForConditionalGeneration, PLBartForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
def remove_ignore_keys_(state_dict):
|
||||||
|
ignore_keys = [
|
||||||
|
"encoder.version",
|
||||||
|
"decoder.version",
|
||||||
|
"model.encoder.version",
|
||||||
|
"model.decoder.version",
|
||||||
|
"_float_tensor",
|
||||||
|
"decoder.output_projection.weight",
|
||||||
|
]
|
||||||
|
for k in ignore_keys:
|
||||||
|
state_dict.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
|
def make_linear_from_emb(emb):
|
||||||
|
vocab_size, emb_size = emb.weight.shape
|
||||||
|
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||||
|
lin_layer.weight.data = emb.weight.data
|
||||||
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_fairseq_plbart_checkpoint_from_disk(
|
||||||
|
checkpoint_path, hf_config_path="uclanlp/plbart-base", finetuned=False, classification=False
|
||||||
|
):
|
||||||
|
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||||
|
remove_ignore_keys_(state_dict)
|
||||||
|
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
||||||
|
|
||||||
|
plbart_config = PLBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
|
||||||
|
|
||||||
|
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||||
|
if not classification:
|
||||||
|
model = PLBartForConditionalGeneration(plbart_config)
|
||||||
|
model.model.load_state_dict(state_dict)
|
||||||
|
if finetuned:
|
||||||
|
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||||
|
|
||||||
|
else:
|
||||||
|
classification_head = {}
|
||||||
|
for key, value in state_dict.copy().items():
|
||||||
|
if key.startswith("classification_heads.sentence_classification_head"):
|
||||||
|
classification_head[key.replace("classification_heads.sentence_classification_head.", "")] = value
|
||||||
|
state_dict.pop(key)
|
||||||
|
model = PLBartForSequenceClassification(plbart_config)
|
||||||
|
model.model.load_state_dict(state_dict)
|
||||||
|
model.classification_head.load_state_dict(classification_head)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument("fairseq_path", type=str, help="model.pt on local filesystem.")
|
||||||
|
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_config",
|
||||||
|
default="uclanlp/plbart-base",
|
||||||
|
type=str,
|
||||||
|
help="Which huggingface architecture to use: plbart-base",
|
||||||
|
)
|
||||||
|
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
|
||||||
|
parser.add_argument(
|
||||||
|
"--classification", action="store_true", help="whether the model is a classification checkpoint"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
model = convert_fairseq_plbart_checkpoint_from_disk(
|
||||||
|
args.fairseq_path,
|
||||||
|
hf_config_path=args.hf_config,
|
||||||
|
finetuned=args.finetuned,
|
||||||
|
classification=args.classification,
|
||||||
|
)
|
||||||
|
model.save_pretrained(args.pytorch_dump_folder_path)
|
||||||
1717
src/transformers/models/plbart/modeling_plbart.py
Executable file
1717
src/transformers/models/plbart/modeling_plbart.py
Executable file
File diff suppressed because it is too large
Load Diff
448
src/transformers/models/plbart/tokenization_plbart.py
Normal file
448
src/transformers/models/plbart/tokenization_plbart.py
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
SPIECE_UNDERLINE = "▁"
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
|
||||||
|
|
||||||
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
|
"vocab_file": {
|
||||||
|
"uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-c-cpp-defect-detection": "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-cs-java": "https://huggingface.co/uclanlp/plbart-cs-java/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-en_XX-java": "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-go-en_XX": "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-java-clone-detection": "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-java-cs": "https://huggingface.co/uclanlp/plbart-java-cs/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-java-en_XX": "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-javascript-en_XX": "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-php-en_XX": "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-python-en_XX": "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-refine-java-medium": "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-refine-java-small": "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model",
|
||||||
|
"uclanlp/plbart-ruby-en_XX": "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
|
"uclanlp/plbart-base": 1024,
|
||||||
|
"uclanlp/plbart-c-cpp-defect-detection": 1024,
|
||||||
|
"uclanlp/plbart-cs-java": 1024,
|
||||||
|
"uclanlp/plbart-en_XX-java": 1024,
|
||||||
|
"uclanlp/plbart-go-en_XX": 1024,
|
||||||
|
"uclanlp/plbart-java-clone-detection": 1024,
|
||||||
|
"uclanlp/plbart-java-cs": 1024,
|
||||||
|
"uclanlp/plbart-java-en_XX": 1024,
|
||||||
|
"uclanlp/plbart-javascript-en_XX": 1024,
|
||||||
|
"uclanlp/plbart-php-en_XX": 1024,
|
||||||
|
"uclanlp/plbart-python-en_XX": 1024,
|
||||||
|
"uclanlp/plbart-refine-java-medium": 1024,
|
||||||
|
"uclanlp/plbart-refine-java-small": 1024,
|
||||||
|
"uclanlp/plbart-ruby-en_XX": 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
FAIRSEQ_LANGUAGE_CODES = {
|
||||||
|
"base": ["java", "python", "en_XX"],
|
||||||
|
"multi": ["java", "python", "en_XX", "javascript", "php", "ruby", "go"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartTokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Construct an PLBART tokenizer.
|
||||||
|
|
||||||
|
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
|
||||||
|
[SentencePiece](https://github.com/google/sentencepiece).
|
||||||
|
|
||||||
|
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
|
||||||
|
<tokens> <eos>``` for target language documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`):
|
||||||
|
Path to the vocabulary file.
|
||||||
|
src_lang (`str`, *optional*):
|
||||||
|
A string representing the source language.
|
||||||
|
tgt_lang (`str`, *optional*):
|
||||||
|
A string representing the target language.
|
||||||
|
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||||
|
The start of sequence token.
|
||||||
|
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
||||||
|
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||||
|
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||||
|
token of a sequence built with special tokens.
|
||||||
|
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
||||||
|
The cls token, which is a special token used as the first token for all tasks.
|
||||||
|
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
mask_token(`str`, *optional*, defaults to `"<mask>"`):
|
||||||
|
The token used for masking values. This is the token used when training this model with masking tasks. This
|
||||||
|
is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
|
||||||
|
downstream tasks.
|
||||||
|
language_codes (`str`, *optional*, defaults to `"base"`):
|
||||||
|
What language codes to use. Should be one of `"base"` or `"multi"`.
|
||||||
|
sp_model_kwargs (`dict`, *optional*):
|
||||||
|
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||||
|
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||||
|
to set:
|
||||||
|
- `enable_sampling`: Enable subword regularization.
|
||||||
|
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
||||||
|
- `nbest_size = {0,1}`: No sampling is performed.
|
||||||
|
- `nbest_size > 1`: samples from the nbest_size results.
|
||||||
|
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
||||||
|
using forward-filtering-and-backward-sampling algorithm.
|
||||||
|
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
||||||
|
BPE-dropout.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import PLBartTokenizer
|
||||||
|
|
||||||
|
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
|
||||||
|
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
|
||||||
|
>>> expected_translation_english = "Returns the maximum value of a b c."
|
||||||
|
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... labels = tokenizer(expected_translation_english, return_tensors="pt")
|
||||||
|
>>> inputs["labels"] = labels["input_ids"]
|
||||||
|
```"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
prefix_tokens: List[int] = []
|
||||||
|
suffix_tokens: List[int] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
bos_token="<s>",
|
||||||
|
eos_token="</s>",
|
||||||
|
sep_token="</s>",
|
||||||
|
cls_token="<s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
pad_token="<pad>",
|
||||||
|
mask_token="<mask>",
|
||||||
|
language_codes="base",
|
||||||
|
tokenizer_file=None,
|
||||||
|
src_lang=None,
|
||||||
|
tgt_lang=None,
|
||||||
|
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
additional_special_tokens=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
|
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
sep_token=sep_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
language_codes=language_codes,
|
||||||
|
tokenizer_file=tokenizer_file,
|
||||||
|
src_lang=src_lang,
|
||||||
|
tgt_lang=tgt_lang,
|
||||||
|
additional_special_tokens=additional_special_tokens,
|
||||||
|
sp_model_kwargs=self.sp_model_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
|
self.sp_model.Load(str(vocab_file))
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
self.language_codes = language_codes
|
||||||
|
|
||||||
|
fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]
|
||||||
|
|
||||||
|
# Original fairseq vocab and spm vocab must be "aligned":
|
||||||
|
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
||||||
|
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
|
||||||
|
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
|
||||||
|
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
|
||||||
|
|
||||||
|
# Mimic fairseq token-to-id alignment for the first 4 token
|
||||||
|
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
|
||||||
|
|
||||||
|
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
||||||
|
self.fairseq_offset = 1
|
||||||
|
|
||||||
|
self.sp_model_size = len(self.sp_model)
|
||||||
|
self.lang_code_to_id = {
|
||||||
|
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
|
||||||
|
}
|
||||||
|
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||||
|
|
||||||
|
if self.language_codes == "base":
|
||||||
|
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||||
|
|
||||||
|
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||||
|
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||||
|
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||||
|
|
||||||
|
if additional_special_tokens is not None:
|
||||||
|
# Only add those special tokens if they are not already there.
|
||||||
|
self._additional_special_tokens.extend(
|
||||||
|
[t for t in additional_special_tokens if t not in self._additional_special_tokens]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.language_codes == "base":
|
||||||
|
self._src_lang = src_lang
|
||||||
|
self.cur_lang_code_id = (
|
||||||
|
self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||||
|
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||||
|
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["sp_model"] = None
|
||||||
|
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
self.__dict__ = d
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if not hasattr(self, "sp_model_kwargs"):
|
||||||
|
self.sp_model_kwargs = {}
|
||||||
|
|
||||||
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
|
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
if self.language_codes == "base":
|
||||||
|
return (
|
||||||
|
len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1
|
||||||
|
) # Plus 1 for the mask token
|
||||||
|
else:
|
||||||
|
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def src_lang(self) -> str:
|
||||||
|
return self._src_lang
|
||||||
|
|
||||||
|
@src_lang.setter
|
||||||
|
def src_lang(self, new_src_lang: str) -> None:
|
||||||
|
self._src_lang = new_src_lang
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer `prepare_for_model` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the token list is already formatted with special tokens for the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if already_has_special_tokens:
|
||||||
|
return super().get_special_tokens_mask(
|
||||||
|
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix_ones = [1] * len(self.prefix_tokens)
|
||||||
|
suffix_ones = [1] * len(self.suffix_tokens)
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||||
|
adding special tokens. An PLBART sequence has the following format, where `X` represents the sequence:
|
||||||
|
|
||||||
|
- `input_ids` (for encoder) `X [eos, src_lang_code]`
|
||||||
|
- `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
|
||||||
|
|
||||||
|
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
||||||
|
separator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs to which the special tokens will be added.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||||
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not
|
||||||
|
make use of token type ids, therefore a list of zeros is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: List of zeros.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sep = [self.sep_token_id]
|
||||||
|
cls = [self.cls_token_id]
|
||||||
|
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return len(cls + token_ids_0 + sep) * [0]
|
||||||
|
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||||
|
|
||||||
|
def _build_translation_inputs(
|
||||||
|
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||||
|
):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||||
|
vocab.update(self.added_tokens_encoder)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
|
return self.sp_model.encode(text, out_type=str)
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token):
|
||||||
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
if token in self.fairseq_tokens_to_ids:
|
||||||
|
return self.fairseq_tokens_to_ids[token]
|
||||||
|
spm_id = self.sp_model.PieceToId(token)
|
||||||
|
|
||||||
|
# Need to return unknown token if the SP model returned 0
|
||||||
|
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
if index in self.fairseq_ids_to_tokens:
|
||||||
|
return self.fairseq_ids_to_tokens[index]
|
||||||
|
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||||
|
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
||||||
|
return out_string
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||||
|
return
|
||||||
|
out_vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||||
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
elif not os.path.isfile(self.vocab_file):
|
||||||
|
with open(out_vocab_file, "wb") as fi:
|
||||||
|
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||||
|
fi.write(content_spiece_model)
|
||||||
|
|
||||||
|
return (out_vocab_file,)
|
||||||
|
|
||||||
|
def prepare_seq2seq_batch(
|
||||||
|
self,
|
||||||
|
src_texts: List[str],
|
||||||
|
src_lang: str = "en_XX",
|
||||||
|
tgt_texts: Optional[List[str]] = None,
|
||||||
|
tgt_lang: str = "python",
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchEncoding:
|
||||||
|
self.src_lang = src_lang
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def as_target_tokenizer(self):
|
||||||
|
"""
|
||||||
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
|
"""
|
||||||
|
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||||
|
yield
|
||||||
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
|
|
||||||
|
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||||
|
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||||
|
self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
|
||||||
|
self.prefix_tokens = []
|
||||||
|
if self.cur_lang_code is not None:
|
||||||
|
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||||
|
else:
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|
||||||
|
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
|
||||||
|
self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
|
||||||
|
self.prefix_tokens = []
|
||||||
|
if self.cur_lang_code is not None:
|
||||||
|
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||||
|
else:
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
@@ -2786,6 +2786,44 @@ class PerceiverPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartForCausalLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartForConditionalGeneration(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartForSequenceClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,13 @@ class PegasusTokenizer(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["sentencepiece"])
|
requires_backends(self, ["sentencepiece"])
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartTokenizer(metaclass=DummyObject):
|
||||||
|
_backends = ["sentencepiece"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["sentencepiece"])
|
||||||
|
|
||||||
|
|
||||||
class ReformerTokenizer(metaclass=DummyObject):
|
class ReformerTokenizer(metaclass=DummyObject):
|
||||||
_backends = ["sentencepiece"]
|
_backends = ["sentencepiece"]
|
||||||
|
|
||||||
|
|||||||
632
tests/test_modeling_plbart.py
Normal file
632
tests/test_modeling_plbart.py
Normal file
@@ -0,0 +1,632 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022, The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Testing suite for the PyTorch PLBART model. """
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import PLBartConfig, is_torch_available
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_generation_utils import GenerationTesterMixin
|
||||||
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
PLBartForCausalLM,
|
||||||
|
PLBartForConditionalGeneration,
|
||||||
|
PLBartForSequenceClassification,
|
||||||
|
PLBartModel,
|
||||||
|
)
|
||||||
|
from transformers.models.plbart.modeling_plbart import PLBartDecoder, PLBartEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_plbart_inputs_dict(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
|
if decoder_attention_mask is None:
|
||||||
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=False,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=16,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=4,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=100,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
input_ids = input_ids.clamp(3)
|
||||||
|
input_ids[:, -1] = self.eos_token_id # Eos Token
|
||||||
|
|
||||||
|
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
inputs_dict = prepare_plbart_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return PLBartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.hidden_size,
|
||||||
|
encoder_layers=self.num_hidden_layers,
|
||||||
|
decoder_layers=self.num_hidden_layers,
|
||||||
|
encoder_attention_heads=self.num_attention_heads,
|
||||||
|
decoder_attention_heads=self.num_attention_heads,
|
||||||
|
encoder_ffn_dim=self.intermediate_size,
|
||||||
|
decoder_ffn_dim=self.intermediate_size,
|
||||||
|
dropout=self.hidden_dropout_prob,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config, inputs_dict = self.prepare_config_and_inputs()
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||||
|
model = PLBartModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
|
||||||
|
output_with_past_key_values = model(
|
||||||
|
next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values
|
||||||
|
)
|
||||||
|
output_from_past = output_with_past_key_values["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
|
model = PLBartModel(config=config).to(torch_device).eval()
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
|
||||||
|
encoder_last_hidden_state = outputs.encoder_last_hidden_state
|
||||||
|
last_hidden_state = outputs.last_hidden_state
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
encoder = model.get_encoder()
|
||||||
|
encoder.save_pretrained(tmpdirname)
|
||||||
|
encoder = PLBartEncoder.from_pretrained(tmpdirname).to(torch_device)
|
||||||
|
|
||||||
|
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
|
||||||
|
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
decoder = model.get_decoder()
|
||||||
|
decoder.save_pretrained(tmpdirname)
|
||||||
|
decoder = PLBartDecoder.from_pretrained(tmpdirname).to(torch_device)
|
||||||
|
|
||||||
|
last_hidden_state_2 = decoder(
|
||||||
|
input_ids=inputs_dict["decoder_input_ids"],
|
||||||
|
attention_mask=inputs_dict["decoder_attention_mask"],
|
||||||
|
encoder_hidden_states=encoder_last_hidden_state,
|
||||||
|
encoder_attention_mask=inputs_dict["attention_mask"],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(PLBartModel, PLBartForConditionalGeneration, PLBartForSequenceClassification) if is_torch_available() else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
is_encoder_decoder = True
|
||||||
|
test_pruning = False
|
||||||
|
test_missing_keys = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = PLBartModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=PLBartConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_save_load_strict(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||||
|
self.assertEqual(info["missing_keys"], [])
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_standalone(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||||
|
|
||||||
|
# PLBartForSequenceClassification does not support inputs_embeds
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in (PLBartModel, PLBartForConditionalGeneration):
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
else:
|
||||||
|
encoder_input_ids = inputs["input_ids"]
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||||
|
del inputs["input_ids"]
|
||||||
|
inputs.pop("decoder_input_ids", None)
|
||||||
|
|
||||||
|
wte = model.get_input_embeddings()
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
inputs["inputs_embeds"] = wte(input_ids)
|
||||||
|
else:
|
||||||
|
inputs["inputs_embeds"] = wte(encoder_input_ids)
|
||||||
|
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model(**inputs)[0]
|
||||||
|
|
||||||
|
def test_generate_fp16(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
input_ids = input_dict["input_ids"]
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
model = PLBartForConditionalGeneration(config).eval().to(torch_device)
|
||||||
|
if torch_device == "cuda":
|
||||||
|
model.half()
|
||||||
|
model.generate(input_ids, attention_mask=attention_mask)
|
||||||
|
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||||
|
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||||
|
if a is None and b is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
if torch.allclose(a, b, atol=atol):
|
||||||
|
return True
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
|
||||||
|
if a.numel() > 100:
|
||||||
|
msg = f"tensor values are {pct_different:.1%} percent different."
|
||||||
|
else:
|
||||||
|
msg = f"{a} != {b}"
|
||||||
|
if prefix:
|
||||||
|
msg = prefix + ": " + msg
|
||||||
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
|
||||||
|
maxDiff = 1000 # longer string compare tracebacks
|
||||||
|
checkpoint_name = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name, use_fast=False)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model(self):
|
||||||
|
"""Only load the model if needed."""
|
||||||
|
model = PLBartForConditionalGeneration.from_pretrained(self.checkpoint_name).to(torch_device)
|
||||||
|
if "cuda" in torch_device:
|
||||||
|
model = model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class PLBartJavaCsIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||||
|
checkpoint_name = "uclanlp/plbart-java-cs"
|
||||||
|
src_text = [
|
||||||
|
"public int maximum(int a, int b, int c){return Math.max(a, Math.max(b, c));}",
|
||||||
|
"public int product(int a, int b, int c){return a*b*c;}",
|
||||||
|
]
|
||||||
|
tgt_text = [
|
||||||
|
"public int maximum(int a, int b, int c){return Math.Max(",
|
||||||
|
"public int Product(int a, int b, int c){return a * b *",
|
||||||
|
]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_java_cs_generate_one(self):
|
||||||
|
batch = self.tokenizer(
|
||||||
|
["public int maximum(int a, int b, int c){return Math.max(a, Math.max(b, c));}"], return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch = batch.to(torch_device)
|
||||||
|
translated_tokens = self.model.generate(**batch)
|
||||||
|
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||||
|
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||||
|
# self.assertEqual(self.tgt_text[1], decoded[1])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_java_cs_generate_batch(self):
|
||||||
|
batch = self.tokenizer(self.src_text, return_tensors="pt", padding=True, truncation=True)
|
||||||
|
batch = batch.to(torch_device)
|
||||||
|
translated_tokens = self.model.generate(**batch)
|
||||||
|
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||||
|
assert self.tgt_text == decoded
|
||||||
|
|
||||||
|
def test_plbart_java_cs_config(self):
|
||||||
|
plbart_models = ["uclanlp/plbart-java-cs"]
|
||||||
|
expected = {"scale_embedding": True}
|
||||||
|
for name in plbart_models:
|
||||||
|
config = PLBartConfig.from_pretrained(name)
|
||||||
|
for k, v in expected.items():
|
||||||
|
try:
|
||||||
|
self.assertEqual(v, getattr(config, k))
|
||||||
|
except AssertionError as e:
|
||||||
|
e.args += (name, k)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def test_plbart_fast_forward(self):
|
||||||
|
config = PLBartConfig(
|
||||||
|
vocab_size=99,
|
||||||
|
d_model=24,
|
||||||
|
encoder_layers=2,
|
||||||
|
decoder_layers=2,
|
||||||
|
encoder_attention_heads=2,
|
||||||
|
decoder_attention_heads=2,
|
||||||
|
encoder_ffn_dim=32,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
max_position_embeddings=48,
|
||||||
|
add_final_layer_norm=True,
|
||||||
|
)
|
||||||
|
lm_model = PLBartForConditionalGeneration(config).to(torch_device)
|
||||||
|
context = torch.tensor(
|
||||||
|
[[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
|
||||||
|
)
|
||||||
|
summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long)
|
||||||
|
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||||
|
expected_shape = (*summary.shape, config.vocab_size)
|
||||||
|
self.assertEqual(result.logits.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||||
|
checkpoint_name = "uclanlp/plbart-base"
|
||||||
|
src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."]
|
||||||
|
tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"]
|
||||||
|
|
||||||
|
# @unittest.skip("This test is broken, still generates english")
|
||||||
|
def test_base_generate(self):
|
||||||
|
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||||
|
translated_tokens = self.model.generate(
|
||||||
|
input_ids=inputs["input_ids"].to(torch_device),
|
||||||
|
decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"],
|
||||||
|
)
|
||||||
|
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||||
|
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_fill_mask(self):
|
||||||
|
inputs = self.tokenizer(["Is 0 the <mask> Fibonacci <mask> ?"], return_tensors="pt").to(torch_device)
|
||||||
|
outputs = self.model.generate(
|
||||||
|
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||||
|
)
|
||||||
|
prediction: str = self.tokenizer.batch_decode(
|
||||||
|
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
|
)[0]
|
||||||
|
self.assertEqual(prediction, "0 0 the 0 the 0 the 0 the 0 the 0 the 0 the 0 the")
|
||||||
|
|
||||||
|
|
||||||
|
class PLBartStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = PLBartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (config, input_ids, attention_mask, lm_labels)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = PLBartDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = PLBartDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(config, input_ids, attention_mask, lm_labels) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (PLBartDecoder, PLBartForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (PLBartForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = PLBartStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=PLBartConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
361
tests/test_tokenization_plbart.py
Normal file
361
tests/test_tokenization_plbart.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import SPIECE_UNDERLINE, BatchEncoding, PLBartTokenizer, is_torch_available
|
||||||
|
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||||
|
|
||||||
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.models.plbart.modeling_plbart import shift_tokens_right
|
||||||
|
|
||||||
|
EN_CODE = 50003
|
||||||
|
PYTHON_CODE = 50002
|
||||||
|
|
||||||
|
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
tokenizer_class = PLBartTokenizer
|
||||||
|
rust_tokenizer_class = None
|
||||||
|
test_rust_tokenizer = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# We have a SentencePiece fixture for testing
|
||||||
|
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
|
||||||
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def test_full_base_tokenizer(self):
|
||||||
|
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("This is a test")
|
||||||
|
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens),
|
||||||
|
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||||
|
self.assertListEqual(
|
||||||
|
tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"9",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"é",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
self.assertListEqual(
|
||||||
|
ids,
|
||||||
|
[
|
||||||
|
value + tokenizer.fairseq_offset
|
||||||
|
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||||
|
self.assertListEqual(
|
||||||
|
back_tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"<unk>",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"<unk>",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
end = tokenizer.vocab_size
|
||||||
|
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
|
||||||
|
|
||||||
|
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
|
||||||
|
|
||||||
|
def test_full_multi_tokenizer(self):
|
||||||
|
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("This is a test")
|
||||||
|
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens),
|
||||||
|
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||||
|
self.assertListEqual(
|
||||||
|
tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"9",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"é",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
self.assertListEqual(
|
||||||
|
ids,
|
||||||
|
[
|
||||||
|
value + tokenizer.fairseq_offset
|
||||||
|
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||||
|
self.assertListEqual(
|
||||||
|
back_tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"<unk>",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"<unk>",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
end = tokenizer.vocab_size
|
||||||
|
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
|
||||||
|
|
||||||
|
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||||
|
checkpoint_name = "uclanlp/plbart-python-en_XX"
|
||||||
|
src_text = [
|
||||||
|
"def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])",
|
||||||
|
"def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])",
|
||||||
|
]
|
||||||
|
tgt_text = [
|
||||||
|
"Returns the maximum value of a b c.",
|
||||||
|
"Sums the values of a b c.",
|
||||||
|
]
|
||||||
|
expected_src_tokens = [
|
||||||
|
134,
|
||||||
|
5452,
|
||||||
|
33460,
|
||||||
|
33441,
|
||||||
|
33463,
|
||||||
|
33465,
|
||||||
|
33463,
|
||||||
|
33449,
|
||||||
|
988,
|
||||||
|
20,
|
||||||
|
33456,
|
||||||
|
19,
|
||||||
|
33456,
|
||||||
|
771,
|
||||||
|
39,
|
||||||
|
4258,
|
||||||
|
889,
|
||||||
|
3318,
|
||||||
|
33441,
|
||||||
|
33463,
|
||||||
|
33465,
|
||||||
|
33463,
|
||||||
|
33449,
|
||||||
|
2471,
|
||||||
|
2,
|
||||||
|
PYTHON_CODE,
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.tokenizer: PLBartTokenizer = PLBartTokenizer.from_pretrained(
|
||||||
|
cls.checkpoint_name, language_codes="base", src_lang="python", tgt_lang="en_XX"
|
||||||
|
)
|
||||||
|
cls.pad_token_id = 1
|
||||||
|
return cls
|
||||||
|
|
||||||
|
def check_language_codes(self):
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
|
||||||
|
|
||||||
|
def test_python_en_tokenizer_batch_encode_plus(self):
|
||||||
|
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||||
|
self.assertListEqual(self.expected_src_tokens, ids)
|
||||||
|
|
||||||
|
def test_python_en_tokenizer_decode_ignores_language_codes(self):
|
||||||
|
self.assertIn(PYTHON_CODE, self.tokenizer.all_special_ids)
|
||||||
|
generated_ids = [EN_CODE, 9037, 33442, 57, 752, 153, 14, 56, 18, 9, 2]
|
||||||
|
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
expected_english = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
|
||||||
|
self.assertEqual(result, expected_english)
|
||||||
|
self.assertNotIn(self.tokenizer.eos_token, result)
|
||||||
|
|
||||||
|
def test_python_en_tokenizer_truncation(self):
|
||||||
|
src_text = ["def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])" * 20]
|
||||||
|
self.assertIsInstance(src_text[0], str)
|
||||||
|
desired_max_length = 10
|
||||||
|
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
|
||||||
|
self.assertEqual(ids[-2], 2)
|
||||||
|
self.assertEqual(ids[-1], PYTHON_CODE)
|
||||||
|
self.assertEqual(len(ids), desired_max_length)
|
||||||
|
|
||||||
|
def test_mask_token(self):
|
||||||
|
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
|
||||||
|
|
||||||
|
def test_special_tokens_unaffacted_by_save_load(self):
|
||||||
|
tmpdirname = tempfile.mkdtemp()
|
||||||
|
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
|
||||||
|
self.tokenizer.save_pretrained(tmpdirname)
|
||||||
|
new_tok = PLBartTokenizer.from_pretrained(tmpdirname)
|
||||||
|
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_batch_fairseq_parity(self):
|
||||||
|
batch = self.tokenizer(self.src_text, padding=True)
|
||||||
|
with self.tokenizer.as_target_tokenizer():
|
||||||
|
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||||
|
labels = targets["input_ids"]
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||||
|
|
||||||
|
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||||
|
self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
|
||||||
|
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
|
||||||
|
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
|
||||||
|
self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_python_en_tokenizer_prepare_batch(self):
|
||||||
|
batch = self.tokenizer(
|
||||||
|
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||||
|
)
|
||||||
|
with self.tokenizer.as_target_tokenizer():
|
||||||
|
targets = self.tokenizer(
|
||||||
|
self.tgt_text,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=len(self.expected_src_tokens),
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
labels = targets["input_ids"]
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
|
self.assertEqual((2, 26), batch.input_ids.shape)
|
||||||
|
self.assertEqual((2, 26), batch.attention_mask.shape)
|
||||||
|
result = batch.input_ids.tolist()[0]
|
||||||
|
self.assertListEqual(self.expected_src_tokens, result)
|
||||||
|
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
|
||||||
|
# Test that special tokens are reset
|
||||||
|
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||||
|
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, PYTHON_CODE])
|
||||||
|
|
||||||
|
def test_seq2seq_max_length(self):
|
||||||
|
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||||
|
with self.tokenizer.as_target_tokenizer():
|
||||||
|
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||||
|
labels = targets["input_ids"]
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_translation(self):
|
||||||
|
inputs = self.tokenizer._build_translation_inputs(
|
||||||
|
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="java"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(inputs),
|
||||||
|
{
|
||||||
|
# A, test, EOS, en_XX
|
||||||
|
"input_ids": [[150, 242, 2, 50003]],
|
||||||
|
"attention_mask": [[1, 1, 1, 1]],
|
||||||
|
# java
|
||||||
|
"forced_bos_token_id": 50001,
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -45,6 +45,9 @@ PRIVATE_MODELS = [
|
|||||||
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||||
# models to ignore for not tested
|
# models to ignore for not tested
|
||||||
"SegformerDecodeHead", # Building part of bigger (tested) model.
|
"SegformerDecodeHead", # Building part of bigger (tested) model.
|
||||||
|
"PLBartEncoder", # Building part of bigger (tested) model.
|
||||||
|
"PLBartDecoder", # Building part of bigger (tested) model.
|
||||||
|
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
||||||
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
||||||
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
|
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
@@ -119,6 +122,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
|||||||
"PerceiverForOpticalFlow",
|
"PerceiverForOpticalFlow",
|
||||||
"SegformerDecodeHead",
|
"SegformerDecodeHead",
|
||||||
"FlaxBeitForMaskedImageModeling",
|
"FlaxBeitForMaskedImageModeling",
|
||||||
|
"PLBartEncoder",
|
||||||
|
"PLBartDecoder",
|
||||||
|
"PLBartDecoderWrapper",
|
||||||
"BeitForMaskedImageModeling",
|
"BeitForMaskedImageModeling",
|
||||||
"CLIPTextModel",
|
"CLIPTextModel",
|
||||||
"CLIPVisionModel",
|
"CLIPVisionModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user