diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 78916d1ebb..e6774c96d8 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -40,7 +40,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_dilbert import (DilBertconfig, DilBertForMaskedLM, DilBertModel, DilBertForSequenceClassification, +from .modeling_dilbert import (DilBertConfig, DilBertForMaskedLM, DilBertModel, + DilBertForSequenceClassification, DilBertForQuestionAnswering, DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) diff --git a/pytorch_transformers/modeling_dilbert.py b/pytorch_transformers/modeling_dilbert.py index b5d7e51b79..1fcb33e9ad 100644 --- a/pytorch_transformers/modeling_dilbert.py +++ b/pytorch_transformers/modeling_dilbert.py @@ -45,7 +45,7 @@ DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { } -class DilBertconfig(PretrainedConfig): +class DilBertConfig(PretrainedConfig): pretrained_config_archive_map = DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, @@ -62,7 +62,7 @@ class DilBertconfig(PretrainedConfig): initializer_range=0.02, tie_weights=True, **kwargs): - super(DilBertconfig, self).__init__(**kwargs) + super(DilBertConfig, self).__init__(**kwargs) if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): @@ -77,6 +77,7 @@ class DilBertconfig(PretrainedConfig): self.n_layers = n_layers self.n_heads = n_heads self.dim = dim + self.hidden_dim = hidden_dim self.dropout = dropout self.attention_dropout = attention_dropout self.activation = activation @@ -341,7 +342,7 @@ class DilBertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = DilBertconfig + config_class = DilBertConfig pretrained_model_archive_map = DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = None base_model_prefix = "dilbert" @@ -370,7 +371,7 @@ DILBERT_START_DOCSTRING = r""" For more information on DilBERT, you should check TODO(Victor): Link to Medium Parameters: - config (:class:`~pytorch_transformers.DilBertconfig`): Model configuration class with all the parameters of the model. + config (:class:`~pytorch_transformers.DilBertConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ @@ -391,18 +392,7 @@ DILBERT_INPUTS_DOCSTRING = r""" @add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.", DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) class DilBertModel(DilBertPreTrainedModel): - def __init__(self, config): - super(DilBertModel, self).__init__(config) - - self.embeddings = Embeddings(config) # Embeddings - self.transformer = Transformer(config) # Encoder - - self.apply(self.init_weights) - - def forward(self, - input_ids: torch.tensor, - attention_mask: torch.tensor = None): - """ + r""" Parameters ---------- input_ids: torch.tensor(bs, seq_length) @@ -422,7 +412,18 @@ class DilBertModel(DilBertPreTrainedModel): all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if output_attentions=True - """ + """ + def __init__(self, config): + super(DilBertModel, self).__init__(config) + + self.embeddings = Embeddings(config) # Embeddings + self.transformer = Transformer(config) # Encoder + + self.apply(self.init_weights) + + def forward(self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) # (bs, seq_length) @@ -438,33 +439,7 @@ class DilBertModel(DilBertPreTrainedModel): @add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """, DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) class DilBertForMaskedLM(DilBertPreTrainedModel): - def __init__(self, config): - super(DilBertForMaskedLM, self).__init__(config) - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - - self.encoder = DilBertModel(config) - self.vocab_transform = nn.Linear(config.dim, config.dim) - self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) - self.vocab_projector = nn.Linear(config.dim, config.vocab_size) - - self.apply(self.init_weights) - self.tie_weights() - - self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) - - def tie_weights_(self): - """ - Tying the weights of the vocabulary projection to the base token embeddings. - """ - if self.config.tie_weights: - self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight - - def forward(self, - input_ids: torch.tensor, - attention_mask: torch.tensor = None, - masked_lm_labels: torch.tensor = None): - """ + r""" Parameters ---------- input_ids: torch.tensor(bs, seq_length) @@ -487,7 +462,33 @@ class DilBertForMaskedLM(DilBertPreTrainedModel): all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if `output_attentions`=True + """ + def __init__(self, config): + super(DilBertForMaskedLM, self).__init__(config) + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + self.encoder = DilBertModel(config) + self.vocab_transform = nn.Linear(config.dim, config.dim) + self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) + self.vocab_projector = nn.Linear(config.dim, config.vocab_size) + + self.apply(self.init_weights) + self.tie_weights_() + + self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + + def tie_weights_(self): """ + Tying the weights of the vocabulary projection to the base token embeddings. + """ + if self.config.tie_weights: + self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight + + def forward(self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None, + masked_lm_labels: torch.tensor = None): tfmr_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask) hidden_states = tfmr_output[0] # (bs, seq_length, dim) @@ -508,22 +509,7 @@ class DilBertForMaskedLM(DilBertPreTrainedModel): the pooled output) e.g. for GLUE tasks. """, DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) class DilBertForSequenceClassification(DilBertPreTrainedModel): - def __init__(self, config): - super(DilBertForSequenceClassification, self).__init__(config) - self.num_labels = config.num_labels - - self.dilbert = DilBertModel(config) - self.pre_classifier = nn.Linear(config.dim, config.dim) - self.classifier = nn.Linear(config.dim, config.num_labels) - self.dropout = nn.Dropout(config.seq_classif_dropout) - - self.apply(self.init_weights) - - def forward(self, - input_ids: torch.tensor, - attention_mask: torch.tensor = None, - labels: torch.tensor = None): - """ + r""" Parameters ---------- input_ids: torch.tensor(bs, seq_length) @@ -546,7 +532,22 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel): all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if `output_attentions`=True - """ + """ + def __init__(self, config): + super(DilBertForSequenceClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.dilbert = DilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, config.num_labels) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + self.apply(self.init_weights) + + def forward(self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None, + labels: torch.tensor = None): dilbert_output = self.dilbert(input_ids=input_ids, attention_mask=attention_mask) pooled_output = dilbert_output[1] # (bs, dim) @@ -571,22 +572,7 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel): the hidden-states output to compute `span start logits` and `span end logits`). """, DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) class DilBertForQuestionAnswering(DilBertPreTrainedModel): - def __init__(self, config): - super(DilBertForQuestionAnswering, self).__init__(config) - - self.dilbert = DilBertModel(config) - self.qa_outputs = nn.Linear(config.dim, config.num_labels) - assert config.num_labels == 2 - self.dropout = nn.Dropout(config.qa_dropout) - - self.apply(self.init_weights) - - def forward(self, - input_ids: torch.tensor, - attention_mask: torch.tensor = None, - start_positions: torch.tensor = None, - end_positions: torch.tensor = None): - """ + r""" Parameters ---------- input_ids: torch.tensor(bs, seq_length) @@ -619,7 +605,22 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel): all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if `output_attentions`=True - """ + """ + def __init__(self, config): + super(DilBertForQuestionAnswering, self).__init__(config) + + self.dilbert = DilBertModel(config) + self.qa_outputs = nn.Linear(config.dim, config.num_labels) + assert config.num_labels == 2 + self.dropout = nn.Dropout(config.qa_dropout) + + self.apply(self.init_weights) + + def forward(self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None, + start_positions: torch.tensor = None, + end_positions: torch.tensor = None): dilbert_output = self.dilbert(input_ids=input_ids, attention_mask=attention_mask) hidden_states = dilbert_output[0] # (bs, max_query_len, dim)