add OpenAI GPT
This commit is contained in:
@@ -416,12 +416,12 @@ class BertPreTrainingHeads(nn.Module):
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module):
|
||||
class BertPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedModel, self).__init__()
|
||||
super(BertPreTrainedModel, self).__init__()
|
||||
if not isinstance(config, BertConfig):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
||||
@@ -447,7 +447,7 @@ class PreTrainedModel(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
@@ -547,13 +547,16 @@ class PreTrainedModel(nn.Module):
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
|
||||
class BertModel(PreTrainedModel):
|
||||
class BertModel(BertPreTrainedModel):
|
||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
||||
|
||||
Params:
|
||||
@@ -636,7 +639,7 @@ class BertModel(PreTrainedModel):
|
||||
return encoded_layers, pooled_output
|
||||
|
||||
|
||||
class BertForPreTraining(PreTrainedModel):
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
"""BERT model with pre-training heads.
|
||||
This module comprises the BERT model followed by the two pre-training heads:
|
||||
- the masked language modeling head, and
|
||||
@@ -707,7 +710,7 @@ class BertForPreTraining(PreTrainedModel):
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertForMaskedLM(PreTrainedModel):
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
"""BERT model with the masked language modeling head.
|
||||
This module comprises the BERT model followed by the masked language modeling head.
|
||||
|
||||
@@ -768,7 +771,7 @@ class BertForMaskedLM(PreTrainedModel):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class BertForNextSentencePrediction(PreTrainedModel):
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
"""BERT model with next sentence prediction head.
|
||||
This module comprises the BERT model followed by the next sentence classification head.
|
||||
|
||||
@@ -830,7 +833,7 @@ class BertForNextSentencePrediction(PreTrainedModel):
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
class BertForSequenceClassification(PreTrainedModel):
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
"""BERT model for classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
@@ -875,7 +878,7 @@ class BertForSequenceClassification(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels=2):
|
||||
def __init__(self, config, num_labels):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
@@ -896,7 +899,7 @@ class BertForSequenceClassification(PreTrainedModel):
|
||||
return logits
|
||||
|
||||
|
||||
class BertForMultipleChoice(PreTrainedModel):
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
"""BERT model for multiple choice tasks.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
@@ -940,7 +943,7 @@ class BertForMultipleChoice(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_choices=2):
|
||||
def __init__(self, config, num_choices):
|
||||
super(BertForMultipleChoice, self).__init__(config)
|
||||
self.num_choices = num_choices
|
||||
self.bert = BertModel(config)
|
||||
@@ -965,7 +968,7 @@ class BertForMultipleChoice(PreTrainedModel):
|
||||
return reshaped_logits
|
||||
|
||||
|
||||
class BertForTokenClassification(PreTrainedModel):
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
"""BERT model for token-level classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the full hidden state of the last layer.
|
||||
@@ -1010,7 +1013,7 @@ class BertForTokenClassification(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels=2):
|
||||
def __init__(self, config, num_labels):
|
||||
super(BertForTokenClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
@@ -1031,7 +1034,7 @@ class BertForTokenClassification(PreTrainedModel):
|
||||
return logits
|
||||
|
||||
|
||||
class BertForQuestionAnswering(PreTrainedModel):
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
"""BERT model for Question Answering (span extraction).
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
|
||||
Reference in New Issue
Block a user