Black 20 release
This commit is contained in:
@@ -240,14 +240,19 @@ class ReformerModelTester:
|
||||
half_input_ids = input_ids[:, :half_seq_len]
|
||||
|
||||
# normal padded
|
||||
attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,)
|
||||
attn_mask = torch.cat(
|
||||
[torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)],
|
||||
dim=-1,
|
||||
)
|
||||
input_ids_padded = torch.cat(
|
||||
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
|
||||
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# shifted padded
|
||||
input_ids_roll = torch.cat(
|
||||
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
|
||||
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
|
||||
dim=-1,
|
||||
)
|
||||
input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
|
||||
attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)
|
||||
@@ -283,13 +288,21 @@ class ReformerModelTester:
|
||||
torch.manual_seed(layer.attention_seed)
|
||||
attn_outputs = layer.attention(hidden_states, attention_mask=input_mask)
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,)
|
||||
torch.allclose(
|
||||
prev_attn_output + attn_outputs.hidden_states,
|
||||
next_attn_output,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
|
||||
torch.manual_seed(layer.feed_forward_seed)
|
||||
feed_forward_hidden_states = layer.feed_forward(next_attn_output)
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
||||
torch.allclose(
|
||||
next_hidden_states,
|
||||
hidden_states + feed_forward_hidden_states,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
@@ -416,7 +429,10 @@ class ReformerModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
start_positions=choice_labels,
|
||||
end_positions=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
@@ -468,7 +484,7 @@ class ReformerModelTester:
|
||||
|
||||
class ReformerTesterMixin:
|
||||
"""
|
||||
Reformer Local and Reformer LSH run essentially the same tests
|
||||
Reformer Local and Reformer LSH run essentially the same tests
|
||||
"""
|
||||
|
||||
def test_config(self):
|
||||
@@ -887,7 +903,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
reformer_output = layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device,
|
||||
[1.6879, -1.3083, -0.4708, 1.3555, -0.6292],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -902,11 +920,15 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
|
||||
layer.eval()
|
||||
reformer_output = layer(
|
||||
prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask,
|
||||
prev_attn_output=hidden_states.clone(),
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device,
|
||||
[1.6439, -1.2306, -0.5108, 1.3006, -0.6537],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -922,7 +944,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device,
|
||||
[1.4212, -2.0576, -0.9688, 1.4599, -0.1344],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -935,10 +959,16 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
|
||||
layer.eval()
|
||||
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
|
||||
reformer_output = layer(
|
||||
prev_attn_output=hidden_states,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
|
||||
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -953,7 +983,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||
output_slice = hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device,
|
||||
[-0.9896, -0.9396, -1.0831, -0.0597, 0.2456],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -967,7 +999,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||
output_slice = hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device,
|
||||
[-1.6791, 0.7171, 0.1594, 0.4063, 1.2584],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -983,7 +1017,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||
output_slice = hidden_states[1, -1, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393], dtype=torch.float, device=torch_device,
|
||||
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
@@ -1005,15 +1041,21 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
# check last grads to cover all proable errors
|
||||
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
||||
expected_grad_slice_word = torch.tensor(
|
||||
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], dtype=torch.float, device=torch_device,
|
||||
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
||||
expected_grad_slice_pos_fac_1 = torch.tensor(
|
||||
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306], dtype=torch.float, device=torch_device,
|
||||
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
||||
expected_grad_slice_pos_fac_2 = torch.tensor(
|
||||
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], dtype=torch.float, device=torch_device,
|
||||
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
|
||||
@@ -1038,15 +1080,21 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
# check last grads to cover all proable errors
|
||||
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
||||
expected_grad_slice_word = torch.tensor(
|
||||
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], dtype=torch.float, device=torch_device,
|
||||
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
||||
expected_grad_slice_pos_fac_1 = torch.tensor(
|
||||
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], dtype=torch.float, device=torch_device,
|
||||
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
||||
expected_grad_slice_pos_fac_2 = torch.tensor(
|
||||
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805], dtype=torch.float, device=torch_device,
|
||||
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user