Correct device assignment
This commit is contained in:
@@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
|||||||
]
|
]
|
||||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||||
self.cls.predictions.dense = resized_dense
|
self.cls.predictions.dense = resized_dense
|
||||||
|
self.cls.predictions.dense.to(self.device)
|
||||||
|
|
||||||
if output_embeddings is not None:
|
if output_embeddings is not None:
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
@@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
|||||||
]
|
]
|
||||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||||
self.cls.predictions.dense = resized_dense
|
self.cls.predictions.dense = resized_dense
|
||||||
|
self.cls.predictions.dense.to(self.device)
|
||||||
|
|
||||||
if output_embeddings is not None:
|
if output_embeddings is not None:
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
|
|||||||
Reference in New Issue
Block a user