Actual fix (#9787)
This commit is contained in:
@@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
# Model parallel
|
# Model parallel
|
||||||
self.model_parallel = False
|
self.model_parallel = False
|
||||||
self.device_map = None
|
self.device_map = None
|
||||||
@@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
self.model_parallel = False
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
|
||||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
||||||
def parallelize(self, device_map=None):
|
def parallelize(self, device_map=None):
|
||||||
@@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
@@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
|
||||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
||||||
def parallelize(self, device_map=None):
|
def parallelize(self, device_map=None):
|
||||||
self.device_map = (
|
self.device_map = (
|
||||||
|
|||||||
Reference in New Issue
Block a user