XLNet use_cache refactor (#5770)

Slightly breaking change, changes functionality for `use_cache` in XLNet: if `use_cache` is True and `mem_len` is 0 or None (which is the case in the base model config), the model behaves like GPT-2 and returns `mems` to be used as past in generation. At training time `use_cache` is overridden and always True.
This commit is contained in:
Teven
2020-07-17 20:24:16 +02:00
committed by GitHub
parent 9750e1300c
commit 0b2da0e592
4 changed files with 135 additions and 47 deletions

View File

@@ -191,8 +191,8 @@ class XLNetModelTester:
model = XLNetModel(config)
model.to(torch_device)
model.eval()
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
base_model_output = model(input_ids_1)
self.parent.assertEqual(len(base_model_output), 2)
self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
@@ -202,6 +202,72 @@ class XLNetModelTester:
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_model_use_cache(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check that decoding one token from cached `mems` matches a full forward pass.

    The model is first run on `input_ids_1` with `use_cache=True`,
    `use_cache=False` and with the config default, and the number of returned
    outputs is compared. One extra token is then appended and the hidden state
    for that token is computed twice: once from the full extended sequence, and
    once from the new token alone plus the `mems` returned by the first pass.
    A randomly chosen hidden dimension of both results must agree within
    `atol=1e-3`.

    Only `config` and `input_ids_1` are used; the remaining parameters are
    accepted so the method can be called with the full unpacked
    `prepare_config_and_inputs()` tuple.
    """
    model = XLNetModel(config=config)
    model.to(torch_device)
    model.eval()

    n_batch = input_ids_1.shape[0]
    n_tokens = input_ids_1.shape[1]
    device = input_ids_1.device

    def build_perm_mask(length):
        # Float tensor of ones on and above the diagonal, zeros below.
        # NOTE(review): per the XLNet `perm_mask` documentation a 1 at [b, i, j]
        # means token i does not attend to token j, which would make this a
        # strictly left-to-right mask -- confirm against the model docs.
        ones = torch.ones(n_batch, length, length, dtype=torch.float, device=device)
        return torch.triu(ones, diagonal=0)

    # first forward pass
    first_mask = build_perm_mask(n_tokens)
    outputs_cache = model(input_ids_1, use_cache=True, perm_mask=first_mask)
    outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=first_mask)
    outputs_conf = model(input_ids_1)

    # The config default must return as many outputs as `use_cache=True`,
    # and `use_cache=False` must return exactly one output fewer.
    self.parent.assertEqual(len(outputs_cache), len(outputs_conf))
    self.parent.assertEqual(len(outputs_cache), len(outputs_no_cache) + 1)

    _, mems = outputs_cache

    # create hypothetical next token and extend to next_input_ids
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
    next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)

    # mask for the extended sequence, and for the single new token on its own
    extended_mask = build_perm_mask(n_tokens + 1)
    # Created on the same device and dtype as the other masks; a bare
    # `torch.ones(...)` would live on the CPU and break on accelerator runs.
    single_mask = torch.ones(n_batch, 1, 1, dtype=torch.float, device=device)

    # second forward pass
    output_from_no_past, _ = model(next_input_ids, perm_mask=extended_mask)
    output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask)

    # select random slice
    random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
    output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
    output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

    # test that outputs are equal for slice
    self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_xlnet_base_model_with_att_output(
self,
config,
@@ -451,7 +517,6 @@ class XLNetModelTester:
@require_torch
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
XLNetModel,
@@ -482,6 +547,12 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_base_model_use_cache(self):
    """Check that, in auto-regressive mode, `use_cache` gives the same results."""
    tester = self.model_tester
    # The seed is set before the inputs are prepared, as in the sibling tests.
    tester.set_seed()
    tester.create_and_check_xlnet_model_use_cache(*tester.prepare_config_and_inputs())
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -874,33 +945,33 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9,
69,
27,
50,
551,
442,
22,
2771,
4901,
19,
21,
45,
668,
21,
24,
11335,
20,
18,
416,
41,
1499,
22,
755,
18,
14285,
9225,
2198,
9,
12943,
4354,
153,
69,
27,
1499,
442,
22,
642,
2771,
24,
11335,
20,
18,
9225,
2198,
9,
69,
27,
442,
22,
2771,
]
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
@@ -910,9 +981,8 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
# <sep><cls>, Rasputin is asked to perform magic.
# He is not able to perform magic, and his father and
# the men are forced to leave the monastery. Rasputin is forced to return to
# <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
# He is asked to perform a ritual of the Virgin Mary. He is asked to perform
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)