Initializer range: make AlbertForMaskedLM inherit from BertPreTrainedModel instead of PreTrainedModel
This commit is contained in:
@@ -6,8 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.configuration_albert import AlbertConfig
|
from transformers.configuration_albert import AlbertConfig
|
||||||
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
|
from transformers.modeling_bert import BertEmbeddings, BertPreTrainedModel, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -362,7 +361,7 @@ class AlbertModel(BertModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
||||||
class AlbertForMaskedLM(PreTrainedModel):
|
class AlbertForMaskedLM(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Labels for computing the masked language modeling loss.
|
Labels for computing the masked language modeling loss.
|
||||||
|
|||||||
Reference in New Issue
Block a user