[Bart] Fix: put dummy_inputs on correct device (#3398)

* Dummy inputs to model.device

* Move self.device to ModuleUtilsMixin
This commit is contained in:
Sam Shleifer
2020-03-26 18:42:09 -04:00
committed by GitHub
parent 1a5aefc95c
commit 2b2a2f8df2
2 changed files with 6 additions and 2 deletions

View File

@@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel):
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": input_ids.ne(pad_token),

View File

@@ -108,6 +108,10 @@ class ModuleUtilsMixin:
module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0
@property
def device(self):
return next(self.parameters()).device
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models.