From 2b2a2f8df27d7ee211f52ea1b482a8ef0baf8ba4 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 26 Mar 2020 18:42:09 -0400 Subject: [PATCH] [Bart] Fix: put dummy_inputs on correct device (#3398) * Dummy inputs to model.device * Move self.device to ModuleUtilsMixin --- src/transformers/modeling_bart.py | 4 ++-- src/transformers/modeling_utils.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 23c513393c..ab44440197 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel): @property def dummy_inputs(self): pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]]) - decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,) + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids) dummy_inputs = { "decoder_input_ids": decoder_input_ids, "attention_mask": input_ids.ne(pad_token), diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e1f5fd2af2..9d4abb2ded 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -108,6 +108,10 @@ class ModuleUtilsMixin: module.mem_rss_post_forward = 0 module.mem_rss_pre_forward = 0 + @property + def device(self): + return next(self.parameters()).device + class PreTrainedModel(nn.Module, ModuleUtilsMixin): r""" Base class for all models.