Black 20 release
This commit is contained in:
@@ -104,10 +104,20 @@ class XLNetModelTester:
|
||||
|
||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(
|
||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||
self.batch_size,
|
||||
self.seq_length + 1,
|
||||
self.seq_length + 1,
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
|
||||
target_mapping = torch.zeros(
|
||||
self.batch_size,
|
||||
1,
|
||||
self.seq_length + 1,
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
|
||||
sequence_labels = None
|
||||
@@ -217,7 +227,11 @@ class XLNetModelTester:
|
||||
|
||||
# first forward pass
|
||||
causal_mask = torch.ones(
|
||||
input_ids_1.shape[0], input_ids_1.shape[1], input_ids_1.shape[1], dtype=torch.float, device=torch_device,
|
||||
input_ids_1.shape[0],
|
||||
input_ids_1.shape[1],
|
||||
input_ids_1.shape[1],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
||||
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
|
||||
@@ -363,7 +377,11 @@ class XLNetModelTester:
|
||||
|
||||
total_loss, mems = result_with_labels.to_tuple()
|
||||
|
||||
result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
result_with_labels = model(
|
||||
input_ids_1,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
|
||||
total_loss, mems = result_with_labels.to_tuple()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user