adding models in readme and auto classes
This commit is contained in:
@@ -122,7 +122,8 @@ At some point in the future, you'll be able to seamlessly move from pre-training
|
||||
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
8. **[DistilBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation).
|
||||
9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
10. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||
10. **[T5](https://github.com/google-research/text-to-text-transfer-transformer)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
11. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||
|
||||
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
|
||||
|
||||
|
||||
@@ -144,5 +144,25 @@ Here is the full list of the currently provided pretrained models together with
|
||||
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
|
||||
| | | | Salesforce's Large-sized CTRL English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| T5 | ``t5-small`` | | 6-layer, 768-hidden, 12-heads, 66M parameters |
|
||||
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint |
|
||||
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``t5-base`` | | 6-layer, 768-hidden, 12-heads, 66M parameters |
|
||||
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
|
||||
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``t5-large`` | | 6-layer, 768-hidden, 12-heads, 82M parameters |
|
||||
| | | | The DistilGPT2 model distilled from the GPT2 model `gpt2` checkpoint. |
|
||||
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``t5-3b`` | | 6-layer, 768-hidden, 12-heads, 82M parameters |
|
||||
| | | | The DistilRoBERTa model distilled from the RoBERTa model `roberta-base` checkpoint. |
|
||||
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``t5-11b`` | | 6-layer, 768-hidden, 12-heads, 82M parameters |
|
||||
| | | | The DistilRoBERTa model distilled from the RoBERTa model `roberta-base` checkpoint. |
|
||||
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
.. <https://huggingface.co/transformers/examples.html>`__
|
||||
@@ -6,6 +6,7 @@ def main():
|
||||
"This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
|
||||
"It should be used as one of: \n"
|
||||
">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
|
||||
">> transformers t5 TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
|
||||
">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
|
||||
">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
|
||||
">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
|
||||
@@ -21,6 +22,23 @@ def main():
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
|
||||
if len(sys.argv) != 5:
|
||||
# pylint: disable=line-too-long
|
||||
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
|
||||
else:
|
||||
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
|
||||
TF_CONFIG = sys.argv.pop()
|
||||
TF_CHECKPOINT = sys.argv.pop()
|
||||
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||
elif sys.argv[1] == "t5":
|
||||
try:
|
||||
from .convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||
except ImportError:
|
||||
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
|
||||
if len(sys.argv) != 5:
|
||||
# pylint: disable=line-too-long
|
||||
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
|
||||
|
||||
@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -46,7 +47,8 @@ if is_torch_available():
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
else:
|
||||
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -56,7 +58,8 @@ else:
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||
None, None, None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
@@ -65,6 +68,7 @@ else:
|
||||
None, None,
|
||||
None, None, None,
|
||||
None, None, None,
|
||||
None, None,
|
||||
None, None)
|
||||
|
||||
|
||||
@@ -85,7 +89,8 @@ MODEL_CLASSES = {
|
||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
}
|
||||
|
||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
||||
|
||||
@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi
|
||||
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
||||
from .modeling_t5 import T5Model, T5WithLMHeadModel
|
||||
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
@@ -47,6 +48,7 @@ class AutoModel(object):
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5Model (T5 model)
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
@@ -70,6 +72,7 @@ class AutoModel(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5Model (T5 model)
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
@@ -136,7 +139,9 @@ class AutoModel(object):
|
||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
@@ -171,6 +176,7 @@ class AutoModelWithLMHead(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5ModelWithLMHead (T5 model)
|
||||
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
||||
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
||||
- contains `bert`: BertForMaskedLM (Bert model)
|
||||
@@ -197,6 +203,7 @@ class AutoModelWithLMHead(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5ModelWithLMHead (T5 model)
|
||||
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
||||
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
||||
- contains `bert`: BertForMaskedLM (Bert model)
|
||||
@@ -262,7 +269,9 @@ class AutoModelWithLMHead(object):
|
||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
@@ -27,6 +27,7 @@ from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceC
|
||||
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
|
||||
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
|
||||
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
|
||||
from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
@@ -45,6 +46,7 @@ class TFAutoModel(object):
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: TFT5Model (T5 model)
|
||||
- contains `distilbert`: TFDistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: TFRobertaModel (RoBERTa model)
|
||||
- contains `bert`: TFBertModel (Bert model)
|
||||
@@ -68,6 +70,7 @@ class TFAutoModel(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: TFT5Model (T5 model)
|
||||
- contains `distilbert`: TFDistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: TFRobertaModel (RoBERTa model)
|
||||
- contains `bert`: TFTFBertModel (Bert model)
|
||||
@@ -133,7 +136,9 @@ class TFAutoModel(object):
|
||||
model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
@@ -169,6 +174,7 @@ class TFAutoModelWithLMHead(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: TFT5WithLMHeadModel (T5 model)
|
||||
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
|
||||
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
|
||||
- contains `bert`: TFBertForMaskedLM (Bert model)
|
||||
@@ -195,6 +201,7 @@ class TFAutoModelWithLMHead(object):
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: TFT5WithLMHeadModel (T5 model)
|
||||
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
|
||||
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
|
||||
- contains `bert`: TFBertForMaskedLM (Bert model)
|
||||
@@ -261,7 +268,9 @@ class TFAutoModelWithLMHead(object):
|
||||
model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
@@ -27,6 +27,7 @@ from .tokenization_xlnet import XLNetTokenizer
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer
|
||||
from .tokenization_t5 import T5Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,6 +42,7 @@ class AutoTokenizer(object):
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5Tokenizer (T5 model)
|
||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
||||
- contains `bert`: BertTokenizer (Bert model)
|
||||
@@ -64,6 +66,7 @@ class AutoTokenizer(object):
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `t5`: T5Tokenizer (T5 model)
|
||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||
- contains `roberta`: RobertaTokenizer (XLM model)
|
||||
- contains `bert`: BertTokenizer (Bert model)
|
||||
@@ -101,7 +104,9 @@ class AutoTokenizer(object):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user