Rename add_start_docstrings_to_callable (#8120)

This commit is contained in:
Sylvain Gugger
2020-10-28 13:42:31 -04:00
committed by GitHub
parent 6241c873cd
commit 378142afdf
55 changed files with 327 additions and 292 deletions

View File

@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_bert import (
BERT_INPUTS_DOCSTRING,
BERT_START_DOCSTRING,
@@ -134,7 +134,7 @@ class DeeBertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
@@ -288,7 +288,7 @@ class DeeBertForSequenceClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.configuration_roberta import RobertaConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_roberta import ROBERTA_INPUTS_DOCSTRING, ROBERTA_START_DOCSTRING, RobertaEmbeddings
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
@@ -45,7 +45,7 @@ class DeeRobertaForSequenceClassification(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,