Hubconf doc - Specia case loading

This commit is contained in:
VictorSanh
2019-05-30 16:06:21 -04:00
parent 96592b544b
commit 372a5c1cee

View File

@@ -191,6 +191,12 @@ def bertForSequenceClassification(*args, **kwargs):
The sequence-level classifier is a linear layer that takes as input the
last hidden state of the first character in the input sequence
(see Figures 3a and 3b in the BERT paper).
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
return model
@@ -201,6 +207,12 @@ def bertForMultipleChoice(*args, **kwargs):
"""
BertForMultipleChoice is a fine-tuning model that includes BertModel and a
linear layer on top of the BertModel.
Args:
num_choices: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2, force_reload=True)
"""
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
return model
@@ -225,6 +237,12 @@ def bertForTokenClassification(*args, **kwargs):
The token-level classifier is a linear layer that takes as input the last
hidden state of the sequence.
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
return model