ALBERT-V2
This commit is contained in:
@@ -18,10 +18,14 @@
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
||||||
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
||||||
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
||||||
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
||||||
|
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
|
||||||
|
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
|
||||||
|
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
|
||||||
|
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
class AlbertConfig(PretrainedConfig):
|
class AlbertConfig(PretrainedConfig):
|
||||||
|
|||||||
@@ -26,9 +26,10 @@ from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_alb
|
|||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
|
||||||
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = AlbertConfig.from_json_file(bert_config_file)
|
config = AlbertConfig.from_json_file(albert_config_file)
|
||||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||||
model = AlbertForMaskedLM(config)
|
model = AlbertForMaskedLM(config)
|
||||||
|
|
||||||
|
|||||||
@@ -30,10 +30,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||||
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||||
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||||
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||||
|
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
||||||
|
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
||||||
|
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
||||||
|
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -538,8 +542,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base')
|
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||||
model = AlbertForSequenceClassification.from_pretrained('albert-base')
|
model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||||
outputs = model(input_ids, labels=labels)
|
outputs = model(input_ids, labels=labels)
|
||||||
|
|||||||
@@ -29,18 +29,26 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
|
|||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
'vocab_file':
|
'vocab_file':
|
||||||
{
|
{
|
||||||
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
|
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
|
||||||
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
|
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
|
||||||
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
|
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
|
||||||
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
|
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
|
||||||
|
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
|
||||||
|
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model",
|
||||||
|
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model",
|
||||||
|
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'albert-base': 512,
|
'albert-base-v1': 512,
|
||||||
'albert-large': 512,
|
'albert-large-v1': 512,
|
||||||
'albert-xlarge': 512,
|
'albert-xlarge-v1': 512,
|
||||||
'albert-xxlarge': 512,
|
'albert-xxlarge-v1': 512,
|
||||||
|
'albert-base-v2': 512,
|
||||||
|
'albert-large-v2': 512,
|
||||||
|
'albert-xlarge-v2': 512,
|
||||||
|
'albert-xxlarge-v2': 512,
|
||||||
}
|
}
|
||||||
|
|
||||||
SPIECE_UNDERLINE = u'▁'
|
SPIECE_UNDERLINE = u'▁'
|
||||||
|
|||||||
Reference in New Issue
Block a user