TF Longformer (#5764)
* improve names and tests longformer * more and better tests for longformer * add first tf test * finalize tf basic op functions * fix merge * tf shape test passes * narrow down discrepancies * make longformer local attn tf work * correct tf longformer * add first global attn function * add more global longformer func * advance tf longformer * finish global attn * upload big model * finish all tests * correct false any statement * fix common tests * make all tests pass except keras save load * fix some tests * fix torch test import * finish tests * fix test * fix torch tf tests * add docs * finish docs * Update src/transformers/modeling_longformer.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_longformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply Lysandres suggestions * reverse to assert statement because function will fail otherwise * applying sylvains recommendations * Update src/transformers/modeling_longformer.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update src/transformers/modeling_tf_longformer.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
committed by
GitHub
parent
3425936643
commit
00bb0b25ed
@@ -33,6 +33,7 @@ if is_torch_available():
|
||||
LongformerForTokenClassification,
|
||||
LongformerForQuestionAnswering,
|
||||
LongformerForMultipleChoice,
|
||||
LongformerSelfAttention,
|
||||
)
|
||||
|
||||
|
||||
@@ -325,7 +326,209 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
|
||||
@require_torch
|
||||
class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
def _get_hidden_states(self):
    """Build the fixed float32 fixture tensor shared by the tests in this class.

    The literal below is 4 rows of 8 values wrapped in one outer list, so the
    returned tensor has shape (1, 4, 8). It is created on ``torch_device``.
    """
    fixture_rows = [
        [
            4.98332758e-01,
            2.69175139e00,
            -7.08081422e-03,
            1.04915401e00,
            -1.83476661e00,
            7.67220476e-01,
            2.98580543e-01,
            2.84803992e-02,
        ],
        [
            -7.58357372e-01,
            4.20635998e-01,
            -4.04739919e-02,
            1.59924145e-01,
            2.05135748e00,
            -1.15997978e00,
            5.37166397e-01,
            2.62873606e-01,
        ],
        [
            -1.69438001e00,
            4.17574660e-01,
            -1.49196962e00,
            -1.76483717e00,
            -1.94566312e-01,
            -1.71183858e00,
            7.72903565e-01,
            -1.11557056e00,
        ],
        [
            5.44028163e-01,
            2.05466114e-01,
            -3.63045868e-01,
            2.41865062e-01,
            3.20348382e-01,
            -9.05611176e-01,
            -1.92690727e-01,
            -1.19917547e00,
        ],
    ]
    # Wrap the rows once more to add the leading dimension of size 1.
    return torch.tensor([fixture_rows], dtype=torch.float32, device=torch_device)
|
||||
|
||||
def test_diagonalize(self):
    """Check the output of ``LongformerSelfAttention._pad_and_diagonalize`` on chunked input.

    The last dimension must grow by ``window_overlap_size - 1``; the first row of
    the first chunk must keep its values on the left with zeros on the right, and
    the last row must have zeros on the left with its values on the right.
    """
    hidden_states = self._get_hidden_states()
    hidden_states = hidden_states.reshape((1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
    window_overlap_size = chunked_hidden_states.shape[2]
    # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
    self.assertEqual(window_overlap_size, 4)

    padded_hidden_states = LongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states)

    self.assertEqual(padded_hidden_states.shape[-1], chunked_hidden_states.shape[-1] + window_overlap_size - 1)

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    self.assertTrue(torch.allclose(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], atol=1e-3))
    self.assertTrue(
        torch.allclose(
            padded_hidden_states[0, 0, 0, 4:],
            torch.zeros((3,), device=torch_device, dtype=torch.float32),
            atol=1e-3,
        )
    )
    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    self.assertTrue(torch.allclose(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], atol=1e-3))
    self.assertTrue(
        torch.allclose(
            padded_hidden_states[0, 0, -1, :3],
            torch.zeros((3,), device=torch_device, dtype=torch.float32),
            atol=1e-3,
        )
    )
|
||||
|
||||
def test_pad_and_transpose_last_two_dims(self):
    """Check shape and contents of ``LongformerSelfAttention._pad_and_transpose_last_two_dims``.

    The shape checks use ``assertEqual``: ``assertTrue(shape, expected)`` treats
    ``expected`` as the failure message and passes for any non-empty shape.
    """
    hidden_states = self._get_hidden_states()
    # The fixture literal is 1 x 4 x 8 (one batch entry, 4 rows of 8 values).
    self.assertEqual(hidden_states.shape, (1, 4, 8))
    padding = (0, 0, 0, 1)

    padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding)
    # NOTE(review): (1, 8, 5) assumes the helper pads one row and then swaps the
    # sizes of the last two dims — confirm against the helper's implementation.
    self.assertEqual(padded_hidden_states.shape, (1, 8, 5))

    expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32)
    self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6))
    # The last original row must sit at flat positions 24..31 of the padded tensor.
    self.assertTrue(torch.allclose(hidden_states[0, -1, :], padded_hidden_states.view(1, -1)[0, 24:32], atol=1e-6))
|
||||
|
||||
def test_chunk(self):
    """Check shape and selected values of ``LongformerSelfAttention._chunk`` output.

    The shape check uses ``assertEqual``: ``assertTrue(shape, expected)`` treats
    ``expected`` as the failure message and can never fail on a non-empty shape.
    """
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))

    chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)

    # Check the shape first so a wrong shape is reported directly rather than as
    # a broadcasting error or mismatch in the slice comparisons below.
    self.assertEqual(chunked_hidden_states.shape, (1, 3, 4, 4))

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = torch.tensor(
        [0.4983, -0.7584, -1.6944], device=torch_device, dtype=torch.float32
    )
    expected_slice_along_chunk = torch.tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], device=torch_device, dtype=torch.float32
    )

    self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3))
    self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3))
|
||||
|
||||
def test_mask_invalid_locations(self):
    """Count the infinite entries left by ``LongformerSelfAttention._mask_invalid_locations``.

    The helper's return value is ignored; each case inspects the tensor that was
    passed in, so every case starts from a fresh clone of the chunked input.
    """
    hidden_states = self._get_hidden_states()

    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))
    chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)

    # assertEqual (rather than assertTrue(a == b)) reports the actual count on failure.
    hid_states_1 = chunked_hidden_states.clone()
    LongformerSelfAttention._mask_invalid_locations(hid_states_1, 1)
    self.assertEqual(torch.isinf(hid_states_1).sum().item(), 8)

    hid_states_2 = chunked_hidden_states.clone()
    LongformerSelfAttention._mask_invalid_locations(hid_states_2, 2)
    self.assertEqual(torch.isinf(hid_states_2).sum().item(), 24)

    # Same masking width on a tensor truncated along the last dimension.
    hid_states_3 = chunked_hidden_states.clone()[:, :, :, :3]
    LongformerSelfAttention._mask_invalid_locations(hid_states_3, 2)
    self.assertEqual(torch.isinf(hid_states_3).sum().item(), 24)

    # Same masking width on a tensor truncated along dimension 2.
    hid_states_4 = chunked_hidden_states.clone()[:, :, 2:, :]
    LongformerSelfAttention._mask_invalid_locations(hid_states_4, 2)
    self.assertEqual(torch.isinf(hid_states_4).sum().item(), 12)
|
||||
|
||||
def test_layer_local_attn(self):
    """Run the first self-attention layer of the tiny random checkpoint and check its output.

    The last two positions of the attention mask are set to -10000; all other
    positions are 0. The shape check uses ``assertEqual`` because
    ``assertTrue(shape, expected)`` treats ``expected`` as the failure message
    and can never fail on a non-empty shape.
    """
    model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
    model.eval()
    layer = model.encoder.layer[0].attention.self.to(torch_device)
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.size()
    attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
    # presumably a large negative value masks the token out — confirm against the layer's forward
    attention_mask[:, :, :, -2:] = -10000
    output_hidden_states = layer(hidden_states, attention_mask)[0]

    self.assertEqual(output_hidden_states.shape, (1, 4, 8))
    self.assertTrue(
        torch.allclose(
            output_hidden_states[0, 1],
            torch.tensor(
                [0.0019, 0.0122, -0.0171, -0.0256, -0.0300, 0.0173, -0.0115, 0.0048],
                dtype=torch.float32,
                device=torch_device,
            ),
            atol=1e-3,
        )
    )
|
||||
|
||||
def test_layer_global_attn(self):
    """Run the first self-attention layer on a batch of two with positive mask entries.

    The batch is the fixture plus a copy shifted by -0.5. The shape check uses
    ``assertEqual`` because ``assertTrue(shape, expected)`` treats ``expected``
    as the failure message and can never fail on a non-empty shape.
    """
    model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
    model.eval()
    layer = model.encoder.layer[0].attention.self.to(torch_device)
    hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0)
    batch_size, seq_length, _ = hidden_states.size()
    attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)

    # create attn mask
    # NOTE(review): presumably > 0 marks global attention, < 0 masks the token and
    # 0 means local attention — confirm against the layer's forward.
    attention_mask[0, :, :, -2:] = 10000.0
    attention_mask[0, :, :, -1:] = -10000.0  # overrides the last position set on the line above
    attention_mask[1, :, :, 1:] = 10000.0
    output_hidden_states = layer(hidden_states, attention_mask)[0]

    self.assertEqual(output_hidden_states.shape, (2, 4, 8))

    self.assertTrue(
        torch.allclose(
            output_hidden_states[0, 2],
            torch.tensor(
                [-0.0651, -0.0393, 0.0309, -0.0342, -0.0066, -0.0155, -0.0209, -0.0494],
                dtype=torch.float32,
                device=torch_device,
            ),
            atol=1e-3,
        )
    )

    self.assertTrue(
        torch.allclose(
            output_hidden_states[1, -2],
            torch.tensor(
                [-0.0405, -0.0384, 0.0396, -0.0374, -0.0341, 0.0136, 0.0014, -0.0571],
                dtype=torch.float32,
                device=torch_device,
            ),
            atol=1e-3,
        )
    )
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||
@@ -371,13 +574,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
input_ids = torch.tensor(
|
||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||
) # long input
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||
|
||||
expected_loss = torch.tensor(0.0074, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
|
||||
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user