[Reformer] Make random seed generator available on random seed and not on model device (#6244)
* improve if/else statement for random seeds
* Apply suggestions from code review
* Update src/transformers/modeling_reformer.py
This commit is contained in:
committed by
GitHub
parent
d5b0a0e235
commit
6c9ba1d8fc
@@ -1399,15 +1399,16 @@ class ReformerLayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# randomize seeds
|
# randomize seeds
|
||||||
if next(self.parameters()).device.type == "cuda":
|
# use cuda generator if available
|
||||||
|
if len(torch.cuda.default_generators) > 0:
|
||||||
# GPU
|
# GPU
|
||||||
device_idx = torch.cuda.current_device()
|
device_idx = torch.cuda.current_device()
|
||||||
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
|
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
|
||||||
torch.cuda.manual_seed(self.attention_seed)
|
|
||||||
else:
|
else:
|
||||||
# CPU
|
# CPU
|
||||||
self.attention_seed = int(torch.seed() % sys.maxsize)
|
self.attention_seed = int(torch.seed() % sys.maxsize)
|
||||||
torch.manual_seed(self.attention_seed)
|
|
||||||
|
torch.manual_seed(self.attention_seed)
|
||||||
|
|
||||||
def _init_feed_forward_seed(self):
|
def _init_feed_forward_seed(self):
|
||||||
"""
|
"""
|
||||||
@@ -1417,17 +1418,17 @@ class ReformerLayer(nn.Module):
|
|||||||
call and 1 forward call in backward
|
call and 1 forward call in backward
|
||||||
to recalculate activations.
|
to recalculate activations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# randomize seeds
|
# randomize seeds
|
||||||
if next(self.parameters()).device.type == "cuda":
|
# use cuda generator if available
|
||||||
|
if len(torch.cuda.default_generators) > 0:
|
||||||
# GPU
|
# GPU
|
||||||
device_idx = torch.cuda.current_device()
|
device_idx = torch.cuda.current_device()
|
||||||
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
|
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
|
||||||
torch.cuda.manual_seed(self.feed_forward_seed)
|
|
||||||
else:
|
else:
|
||||||
# CPU
|
# CPU
|
||||||
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
|
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
|
||||||
torch.manual_seed(self.feed_forward_seed)
|
|
||||||
|
torch.manual_seed(self.feed_forward_seed)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user