[Longformer] Major Refactor (#5219)
* refactor naming
* add small slow test
* refactor
* refactor naming
* rename selected to extra
* big global attention refactor
* make style
* refactor naming
* save intermed
* refactor functions
* finish function refactor
* fix tests
* fix longformer
* fix longformer
* fix longformer
* fix all tests but one
* finish longformer
* address sams and izs comments
* fix transpose
This commit is contained in:
committed by
GitHub
parent
e0d58ddb65
commit
d697b6ca75
File diff suppressed because it is too large
Load Diff
@@ -811,7 +811,7 @@ class ModelTesterMixin:
|
|||||||
# Wrap model in nn.DataParallel
|
# Wrap model in nn.DataParallel
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = model(**inputs_dict)
|
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|||||||
@@ -115,6 +115,18 @@ class LongformerModelTester:
|
|||||||
def check_loss_output(self, result):
|
def check_loss_output(self, result):
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
|
|
||||||
|
def create_and_check_attention_mask_determinism(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = LongformerModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
|
||||||
|
output_without_mask = model(input_ids)[0]
|
||||||
|
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
|
||||||
|
|
||||||
def create_and_check_longformer_model(
|
def create_and_check_longformer_model(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -134,6 +146,36 @@ class LongformerModelTester:
|
|||||||
)
|
)
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_longformer_model_with_global_attention_mask(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = LongformerModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
global_attention_mask = input_mask.clone()
|
||||||
|
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
|
||||||
|
global_attention_mask = global_attention_mask.to(torch_device)
|
||||||
|
|
||||||
|
sequence_output, pooled_output = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
global_attention_mask=global_attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
sequence_output, pooled_output = model(
|
||||||
|
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
|
||||||
|
)
|
||||||
|
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"sequence_output": sequence_output,
|
||||||
|
"pooled_output": pooled_output,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
def create_and_check_longformer_for_masked_lm(
|
def create_and_check_longformer_for_masked_lm(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -243,7 +285,13 @@ class LongformerModelTester:
|
|||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
global_attention_mask = torch.zeros_like(input_ids)
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"global_attention_mask": global_attention_mask,
|
||||||
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_question_answering(self):
|
def prepare_config_and_inputs_for_question_answering(self):
|
||||||
@@ -277,11 +325,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
(
|
(
|
||||||
LongformerModel,
|
LongformerModel,
|
||||||
LongformerForMaskedLM,
|
LongformerForMaskedLM,
|
||||||
# TODO: make tests pass for those models
|
LongformerForSequenceClassification,
|
||||||
# LongformerForSequenceClassification,
|
LongformerForQuestionAnswering,
|
||||||
# LongformerForQuestionAnswering,
|
LongformerForTokenClassification,
|
||||||
# LongformerForTokenClassification,
|
LongformerForMultipleChoice,
|
||||||
# LongformerForMultipleChoice,
|
|
||||||
)
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
@@ -298,6 +345,14 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_longformer_model_attention_mask_determinism(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_longformer_model_global_attention_mask(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||||
|
|
||||||
def test_longformer_for_masked_lm(self):
|
def test_longformer_for_masked_lm(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||||
@@ -325,15 +380,31 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
|
# 'Hello world!'
|
||||||
|
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device)
|
||||||
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||||
|
output_without_mask = model(input_ids)[0]
|
||||||
|
|
||||||
|
expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device)
|
||||||
|
self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4))
|
||||||
|
self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_no_head_long(self):
|
||||||
|
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
# 'Hello world! ' repeated 1000 times
|
# 'Hello world! ' repeated 1000 times
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||||
) # long input
|
) # long input
|
||||||
|
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
||||||
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
|
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
||||||
|
global_attention_mask[:, [1, 4, 21]] = 1 # Set global attention on a few random positions
|
||||||
|
|
||||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
|
||||||
|
|
||||||
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
|
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
|
||||||
expected_output_mean = torch.tensor(0.0243, device=torch_device)
|
expected_output_mean = torch.tensor(0.0243, device=torch_device)
|
||||||
@@ -341,7 +412,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm_long(self):
|
||||||
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
@@ -352,9 +423,9 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||||
|
|
||||||
expected_loss = torch.tensor(0.0620, device=torch_device)
|
expected_loss = torch.tensor(0.0074, device=torch_device)
|
||||||
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
|
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
|
||||||
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
|
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
|
||||||
input_ids = input_ids.to(torch_device)
|
input_ids = input_ids.to(torch_device)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
|
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user