doc and conversion
This commit is contained in:
@@ -119,5 +119,8 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
|
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
|
||||||
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
|
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
|
||||||
|
| | | | Salesforce's Large-sized CTRL English model |
|
||||||
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
|
||||||
.. <https://huggingface.co/transformers/examples.html>`__
|
.. <https://huggingface.co/transformers/examples.html>`__
|
||||||
@@ -155,6 +155,11 @@ if is_tf_available():
|
|||||||
load_distilbert_pt_weights_in_tf2,
|
load_distilbert_pt_weights_in_tf2,
|
||||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
||||||
|
TFCTRLLMHeadModel,
|
||||||
|
load_ctrl_pt_weights_in_tf2,
|
||||||
|
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
# TF 2.0 <=> PyTorch conversion utilities
|
# TF 2.0 <=> PyTorch conversion utilities
|
||||||
if is_tf_available() and is_torch_available():
|
if is_tf_available() and is_torch_available():
|
||||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns
|
|||||||
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -43,7 +44,8 @@ if is_torch_available():
|
|||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
else:
|
else:
|
||||||
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
@@ -52,7 +54,8 @@ else:
|
|||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
|
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||||
None, None, None, None,
|
None, None, None, None,
|
||||||
None, None,
|
None, None,
|
||||||
None, None,
|
None, None,
|
||||||
@@ -60,7 +63,8 @@ else:
|
|||||||
None, None,
|
None, None,
|
||||||
None, None,
|
None, None,
|
||||||
None, None, None,
|
None, None, None,
|
||||||
None, None, None,)
|
None, None, None,
|
||||||
|
None, None)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -80,6 +84,7 @@ MODEL_CLASSES = {
|
|||||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
||||||
@@ -228,6 +233,7 @@ if __name__ == "__main__":
|
|||||||
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
||||||
args.tf_dump_path,
|
args.tf_dump_path,
|
||||||
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
|
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
|
||||||
|
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||||
compare_with_pt_model=args.compare_with_pt_model,
|
compare_with_pt_model=args.compare_with_pt_model,
|
||||||
use_cached_models=args.use_cached_models,
|
use_cached_models=args.use_cached_models,
|
||||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
||||||
|
|||||||
Reference in New Issue
Block a user