Hubconf doc - Specia case loading

This commit is contained in:
VictorSanh
2019-05-30 16:06:21 -04:00
parent 96592b544b
commit 372a5c1cee

View File

@@ -191,6 +191,12 @@ def bertForSequenceClassification(*args, **kwargs):
The sequence-level classifier is a linear layer that takes as input the The sequence-level classifier is a linear layer that takes as input the
last hidden state of the first character in the input sequence last hidden state of the first character in the input sequence
(see Figures 3a and 3b in the BERT paper). (see Figures 3a and 3b in the BERT paper).
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2, force_reload=True)
""" """
model = BertForSequenceClassification.from_pretrained(*args, **kwargs) model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
return model return model
@@ -201,6 +207,12 @@ def bertForMultipleChoice(*args, **kwargs):
""" """
BertForMultipleChoice is a fine-tuning model that includes BertModel and a BertForMultipleChoice is a fine-tuning model that includes BertModel and a
linear layer on top of the BertModel. linear layer on top of the BertModel.
Args:
num_choices: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2, force_reload=True)
""" """
model = BertForMultipleChoice.from_pretrained(*args, **kwargs) model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
return model return model
@@ -225,6 +237,12 @@ def bertForTokenClassification(*args, **kwargs):
The token-level classifier is a linear layer that takes as input the last The token-level classifier is a linear layer that takes as input the last
hidden state of the sequence. hidden state of the sequence.
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2, force_reload=True)
""" """
model = BertForTokenClassification.from_pretrained(*args, **kwargs) model = BertForTokenClassification.from_pretrained(*args, **kwargs)
return model return model