[Bart] Fix: put dummy_inputs on correct device (#3398)
* Dummy inputs to model.device * Move self.device to ModuleUtilsMixin
This commit is contained in:
@@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel):
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
|
||||
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
|
||||
dummy_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": input_ids.ne(pad_token),
|
||||
|
||||
@@ -108,6 +108,10 @@ class ModuleUtilsMixin:
|
||||
module.mem_rss_post_forward = 0
|
||||
module.mem_rss_pre_forward = 0
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
r""" Base class for all models.
|
||||
|
||||
Reference in New Issue
Block a user