Correct device assignment
This commit is contained in:
@@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
]
|
||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
@@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
]
|
||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
Reference in New Issue
Block a user