Correct device assignment

This commit is contained in:
Lysandre
2020-06-19 21:58:17 -04:00
parent 9a3f91088c
commit d97b4176e5

View File

@@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())