add OpenAI GPT

This commit is contained in:
thomwolf
2019-01-08 12:26:58 +01:00
parent 793dcd236b
commit eed51c5bdf
8 changed files with 573 additions and 270 deletions

View File

@@ -416,12 +416,12 @@ class BertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score
class PreTrainedModel(nn.Module):
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
super(BertPreTrainedModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
@@ -447,7 +447,7 @@ class PreTrainedModel(nn.Module):
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
@@ -547,13 +547,16 @@ class PreTrainedModel(nn.Module):
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
    # Report the class of the model being loaded, exactly as the logging
    # calls just above do. `self` is not available here, so referencing it
    # would replace the intended RuntimeError with a NameError.
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        model.__class__.__name__, "\n\t".join(error_msgs)))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
class BertModel(PreTrainedModel):
class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Encoder Representations from Transformers").
Params:
@@ -636,7 +639,7 @@ class BertModel(PreTrainedModel):
return encoded_layers, pooled_output
class BertForPreTraining(PreTrainedModel):
class BertForPreTraining(BertPreTrainedModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
@@ -707,7 +710,7 @@ class BertForPreTraining(PreTrainedModel):
return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
@@ -768,7 +771,7 @@ class BertForMaskedLM(PreTrainedModel):
return prediction_scores
class BertForNextSentencePrediction(PreTrainedModel):
class BertForNextSentencePrediction(BertPreTrainedModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
@@ -830,7 +833,7 @@ class BertForNextSentencePrediction(PreTrainedModel):
return seq_relationship_score
class BertForSequenceClassification(PreTrainedModel):
class BertForSequenceClassification(BertPreTrainedModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
@@ -875,7 +878,7 @@ class BertForSequenceClassification(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
@@ -896,7 +899,7 @@ class BertForSequenceClassification(PreTrainedModel):
return logits
class BertForMultipleChoice(PreTrainedModel):
class BertForMultipleChoice(BertPreTrainedModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
@@ -940,7 +943,7 @@ class BertForMultipleChoice(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
def __init__(self, config, num_choices):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
@@ -965,7 +968,7 @@ class BertForMultipleChoice(PreTrainedModel):
return reshaped_logits
class BertForTokenClassification(PreTrainedModel):
class BertForTokenClassification(BertPreTrainedModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
@@ -1010,7 +1013,7 @@ class BertForTokenClassification(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
def __init__(self, config, num_labels):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
@@ -1031,7 +1034,7 @@ class BertForTokenClassification(PreTrainedModel):
return logits
class BertForQuestionAnswering(PreTrainedModel):
class BertForQuestionAnswering(BertPreTrainedModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits