Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -104,10 +104,20 @@ class XLNetModelTester:
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
self.batch_size,
self.seq_length + 1,
self.seq_length + 1,
dtype=torch.float,
device=torch_device,
)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
target_mapping = torch.zeros(
self.batch_size,
1,
self.seq_length + 1,
dtype=torch.float,
device=torch_device,
)
target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels = None
@@ -217,7 +227,11 @@ class XLNetModelTester:
# first forward pass
causal_mask = torch.ones(
input_ids_1.shape[0], input_ids_1.shape[1], input_ids_1.shape[1], dtype=torch.float, device=torch_device,
input_ids_1.shape[0],
input_ids_1.shape[1],
input_ids_1.shape[1],
dtype=torch.float,
device=torch_device,
)
causal_mask = torch.triu(causal_mask, diagonal=0)
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
@@ -363,7 +377,11 @@ class XLNetModelTester:
total_loss, mems = result_with_labels.to_tuple()
result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
result_with_labels = model(
input_ids_1,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
total_loss, mems = result_with_labels.to_tuple()