fix for pytorch < 1.6 (#6300)

This commit is contained in:
Patrick von Platen
2020-08-06 21:14:46 +02:00
committed by GitHub
parent 2804fff839
commit 118ecfd427

View File

@@ -1400,7 +1400,7 @@ class ReformerLayer(nn.Module):
# randomize seeds
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
# GPU
device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
@@ -1420,7 +1420,7 @@ class ReformerLayer(nn.Module):
"""
# randomize seeds
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
# GPU
device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()