@@ -104,7 +104,7 @@ class LlamaModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
|
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class MistralModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
|
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class PersimmonModelTester:
|
|||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
|
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
|
|||||||
Reference in New Issue
Block a user