diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 23c513393c..ab44440197 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel): @property def dummy_inputs(self): pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]]) - decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,) + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids) dummy_inputs = { "decoder_input_ids": decoder_input_ids, "attention_mask": input_ids.ne(pad_token), diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e1f5fd2af2..9d4abb2ded 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -108,6 +108,10 @@ class ModuleUtilsMixin: module.mem_rss_post_forward = 0 module.mem_rss_pre_forward = 0 + @property + def device(self): + return next(self.parameters()).device + class PreTrainedModel(nn.Module, ModuleUtilsMixin): r""" Base class for all models.