[Bart] Fix: put dummy_inputs on correct device (#3398)

* Dummy inputs to model.device

* Move self.device to ModuleUtilsMixin
This commit is contained in:
Sam Shleifer
2020-03-26 18:42:09 -04:00
committed by GitHub
parent 1a5aefc95c
commit 2b2a2f8df2
2 changed files with 6 additions and 2 deletions

View File

@@ -108,6 +108,10 @@ class ModuleUtilsMixin:
module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0
@property
def device(self):
return next(self.parameters()).device
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models.