From d97b4176e5e9acdab930d73a7cb308b12bd4ad9e Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Jun 2020 21:58:17 -0400 Subject: [PATCH] Correct device assignment --- src/transformers/modeling_mobilebert.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py index 48e7f875dd..6d536a0135 100644 --- a/src/transformers/modeling_mobilebert.py +++ b/src/transformers/modeling_mobilebert.py @@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): ] resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data self.cls.predictions.dense = resized_dense + self.cls.predictions.dense.to(self.device) if output_embeddings is not None: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) @@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): ] resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data self.cls.predictions.dense = resized_dense + self.cls.predictions.dense.to(self.device) if output_embeddings is not None: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())