[Bart] Fix: put dummy_inputs on correct device (#3398)
* Dummy inputs to model.device * Move self.device to ModuleUtilsMixin
This commit is contained in:
@@ -108,6 +108,10 @@ class ModuleUtilsMixin:
|
||||
module.mem_rss_post_forward = 0
|
||||
module.mem_rss_pre_forward = 0
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
r""" Base class for all models.
|
||||
|
||||
Reference in New Issue
Block a user