[Reformer] Select the random seed generator based on CUDA generator availability, not on the model device (#6244)

* improve the if/else statement for the random seeds

* Apply suggestions from code review

* Update src/transformers/modeling_reformer.py
This commit is contained in:
Patrick von Platen
2020-08-04 19:22:43 +02:00
committed by GitHub
parent d5b0a0e235
commit 6c9ba1d8fc

View File

@@ -1399,15 +1399,16 @@ class ReformerLayer(nn.Module):
""" """
# randomize seeds # randomize seeds
if next(self.parameters()).device.type == "cuda": # use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed() self.attention_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.attention_seed)
else: else:
# CPU # CPU
self.attention_seed = int(torch.seed() % sys.maxsize) self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)
torch.manual_seed(self.attention_seed)
def _init_feed_forward_seed(self): def _init_feed_forward_seed(self):
""" """
@@ -1417,17 +1418,17 @@ class ReformerLayer(nn.Module):
call and 1 forward call in backward call and 1 forward call in backward
to recalculate activations. to recalculate activations.
""" """
# randomize seeds # randomize seeds
if next(self.parameters()).device.type == "cuda": # use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU # GPU
device_idx = torch.cuda.current_device() device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.feed_forward_seed)
else: else:
# CPU # CPU
self.feed_forward_seed = int(torch.seed() % sys.maxsize) self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)
torch.manual_seed(self.feed_forward_seed)
def forward( def forward(
self, self,