finish reformer qa head (#5433)

This commit is contained in:
Patrick von Platen
2020-07-01 18:27:14 +02:00
committed by GitHub
parent d697b6ca75
commit fe81f7d12c
5 changed files with 161 additions and 21 deletions

View File

@@ -29,6 +29,7 @@ if is_torch_available():
ReformerModelWithLMHead,
ReformerTokenizer,
ReformerLayer,
ReformerForQuestionAnswering,
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
import torch
@@ -43,6 +44,7 @@ class ReformerModelTester:
is_training=None,
is_decoder=None,
use_input_mask=None,
use_labels=None,
vocab_size=None,
attention_head_size=None,
hidden_size=None,
@@ -81,6 +83,7 @@ class ReformerModelTester:
self.is_training = is_training
self.is_decoder = is_decoder
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
@@ -128,6 +131,10 @@ class ReformerModelTester:
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
choice_labels = None
if self.use_labels:
choice_labels = ids_tensor([self.batch_size], 2)
config = ReformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
@@ -160,14 +167,13 @@ class ReformerModelTester:
config,
input_ids,
input_mask,
choice_labels,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_reformer_model(
self, config, input_ids, input_mask,
):
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
@@ -182,18 +188,14 @@ class ReformerModelTester:
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
)
def create_and_check_reformer_model_with_lm_backward(
self, config, input_ids, input_mask,
):
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward()
def create_and_check_reformer_with_lm(
self, config, input_ids, input_mask,
):
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
@@ -207,7 +209,7 @@ class ReformerModelTester:
)
self.check_loss_output(result)
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
# no special position embeddings
config.axial_pos_embds = False
config.is_decoder = is_decoder
@@ -248,7 +250,7 @@ class ReformerModelTester:
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
config.is_decoder = is_decoder
layer = ReformerLayer(config).to(torch_device)
layer.train()
@@ -281,7 +283,7 @@ class ReformerModelTester:
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
@@ -299,7 +301,7 @@ class ReformerModelTester:
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
@@ -341,7 +343,7 @@ class ReformerModelTester:
torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
)
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
layer = ReformerLayer(config).to(torch_device)
layer.train()
@@ -372,7 +374,7 @@ class ReformerModelTester:
seeds.append(layer.feed_forward_seed)
self.parent.assertGreater(len(set(seeds)), 70)
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
model = ReformerModel(config=config)
model.to(torch_device)
model.half()
@@ -380,7 +382,7 @@ class ReformerModelTester:
output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.half()
@@ -388,7 +390,7 @@ class ReformerModelTester:
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
# force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
@@ -398,9 +400,25 @@ class ReformerModelTester:
output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
model = ReformerForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs
(config, input_ids, input_mask, choice_labels) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@@ -470,7 +488,9 @@ class ReformerTesterMixin:
@require_torch
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
)
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
@@ -483,6 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
"is_training": True,
"is_decoder": False,
"use_input_mask": True,
"use_labels": True,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 32,
@@ -524,7 +545,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
)
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
@@ -535,6 +558,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"batch_size": 13,
"seq_length": 13,
"use_input_mask": True,
"use_labels": True,
"is_training": False,
"is_decoder": False,
"vocab_size": 32,