Fix device of masks in tests (#27887)

fix device of mask in tests
This commit is contained in:
fxmarty
2023-12-07 13:34:43 +01:00
committed by GitHub
parent fc71e815f6
commit c99f254763
3 changed files with 3 additions and 3 deletions

View File

@@ -104,7 +104,7 @@ class LlamaModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -107,7 +107,7 @@ class MistralModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -104,7 +104,7 @@ class PersimmonModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids: