ALBERT-V2
This commit is contained in:
@@ -30,10 +30,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
@@ -538,8 +542,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base')
|
||||
model = AlbertForSequenceClassification.from_pretrained('albert-base')
|
||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||
model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
|
||||
Reference in New Issue
Block a user