ALBERT-V2

This commit is contained in:
Lysandre
2019-11-04 11:34:30 -05:00
committed by Lysandre Debut
parent c110c41fdb
commit 70d99980de
4 changed files with 37 additions and 20 deletions

View File

@@ -30,10 +30,14 @@ logger = logging.getLogger(__name__)
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
}
@@ -538,8 +542,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
Examples::
tokenizer = AlbertTokenizer.from_pretrained('albert-base')
model = AlbertForSequenceClassification.from_pretrained('albert-base')
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)