finish reformer qa head (#5433)
This commit is contained in:
committed by
GitHub
parent
d697b6ca75
commit
fe81f7d12c
@@ -29,6 +29,7 @@ if is_torch_available():
|
||||
ReformerModelWithLMHead,
|
||||
ReformerTokenizer,
|
||||
ReformerLayer,
|
||||
ReformerForQuestionAnswering,
|
||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
import torch
|
||||
@@ -43,6 +44,7 @@ class ReformerModelTester:
|
||||
is_training=None,
|
||||
is_decoder=None,
|
||||
use_input_mask=None,
|
||||
use_labels=None,
|
||||
vocab_size=None,
|
||||
attention_head_size=None,
|
||||
hidden_size=None,
|
||||
@@ -81,6 +83,7 @@ class ReformerModelTester:
|
||||
self.is_training = is_training
|
||||
self.is_decoder = is_decoder
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.attention_head_size = attention_head_size
|
||||
self.hidden_size = hidden_size
|
||||
@@ -128,6 +131,10 @@ class ReformerModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
choice_labels = ids_tensor([self.batch_size], 2)
|
||||
|
||||
config = ReformerConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -160,14 +167,13 @@ class ReformerModelTester:
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_reformer_model(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -182,18 +188,14 @@ class ReformerModelTester:
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_reformer_with_lm(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -207,7 +209,7 @@ class ReformerModelTester:
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
|
||||
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
# no special position embeddings
|
||||
config.axial_pos_embds = False
|
||||
config.is_decoder = is_decoder
|
||||
@@ -248,7 +250,7 @@ class ReformerModelTester:
|
||||
|
||||
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
|
||||
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
config.is_decoder = is_decoder
|
||||
layer = ReformerLayer(config).to(torch_device)
|
||||
layer.train()
|
||||
@@ -281,7 +283,7 @@ class ReformerModelTester:
|
||||
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -299,7 +301,7 @@ class ReformerModelTester:
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
if not self.is_training:
|
||||
return
|
||||
|
||||
@@ -341,7 +343,7 @@ class ReformerModelTester:
|
||||
torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
|
||||
layer = ReformerLayer(config).to(torch_device)
|
||||
layer.train()
|
||||
|
||||
@@ -372,7 +374,7 @@ class ReformerModelTester:
|
||||
seeds.append(layer.feed_forward_seed)
|
||||
self.parent.assertGreater(len(set(seeds)), 70)
|
||||
|
||||
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
@@ -380,7 +382,7 @@ class ReformerModelTester:
|
||||
output = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
@@ -388,7 +390,7 @@ class ReformerModelTester:
|
||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
# force chunk length to be bigger than input_ids
|
||||
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
@@ -398,9 +400,25 @@ class ReformerModelTester:
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask,) = config_and_inputs
|
||||
(config, input_ids, input_mask, choice_labels) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -470,7 +488,9 @@ class ReformerTesterMixin:
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -483,6 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
"is_training": True,
|
||||
"is_decoder": False,
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"vocab_size": 32,
|
||||
"attention_head_size": 16,
|
||||
"hidden_size": 32,
|
||||
@@ -524,7 +545,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
|
||||
@require_torch
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -535,6 +558,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
"batch_size": 13,
|
||||
"seq_length": 13,
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"is_training": False,
|
||||
"is_decoder": False,
|
||||
"vocab_size": 32,
|
||||
|
||||
Reference in New Issue
Block a user