@@ -191,8 +191,8 @@ class XLNetModelTester:
|
||||
model = XLNetModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
base_model_output = model(input_ids_1)
|
||||
self.parent.assertEqual(len(base_model_output), 2)
|
||||
no_mems_outputs = model(input_ids_1)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
@@ -202,72 +202,6 @@ class XLNetModelTester:
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_model_use_cache(
|
||||
self,
|
||||
config,
|
||||
input_ids_1,
|
||||
input_ids_2,
|
||||
input_ids_q,
|
||||
perm_mask,
|
||||
input_mask,
|
||||
target_mapping,
|
||||
segment_ids,
|
||||
lm_labels,
|
||||
sequence_labels,
|
||||
is_impossible_labels,
|
||||
token_labels,
|
||||
):
|
||||
model = XLNetModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
causal_mask = torch.ones(
|
||||
input_ids_1.shape[0],
|
||||
input_ids_1.shape[1],
|
||||
input_ids_1.shape[1],
|
||||
dtype=torch.float,
|
||||
device=input_ids_1.device,
|
||||
)
|
||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
||||
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
|
||||
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
|
||||
outputs_conf = model(input_ids_1)
|
||||
|
||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
|
||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
|
||||
|
||||
output, mems = outputs_cache
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# append to next input_ids and token_type_ids
|
||||
next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)
|
||||
|
||||
# causal mask
|
||||
causal_mask = torch.ones(
|
||||
input_ids_1.shape[0],
|
||||
input_ids_1.shape[1] + 1,
|
||||
input_ids_1.shape[1] + 1,
|
||||
dtype=torch.float,
|
||||
device=input_ids_1.device,
|
||||
)
|
||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
||||
single_mask = torch.ones(input_ids_1.shape[0], 1, 1)
|
||||
|
||||
# second forward pass
|
||||
output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask)
|
||||
output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask)
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_xlnet_base_model_with_att_output(
|
||||
self,
|
||||
config,
|
||||
@@ -517,6 +451,7 @@ class XLNetModelTester:
|
||||
|
||||
@require_torch
|
||||
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
XLNetModel,
|
||||
@@ -547,12 +482,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
||||
|
||||
def test_xlnet_base_model_use_cache(self):
|
||||
# checking that in auto-regressive mode, `use_cache` gives the same results
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
|
||||
|
||||
def test_xlnet_base_model_with_att_output(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -945,33 +874,33 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
9,
|
||||
69,
|
||||
27,
|
||||
442,
|
||||
50,
|
||||
551,
|
||||
22,
|
||||
2771,
|
||||
24,
|
||||
11335,
|
||||
20,
|
||||
4901,
|
||||
19,
|
||||
21,
|
||||
45,
|
||||
668,
|
||||
21,
|
||||
18,
|
||||
9225,
|
||||
2198,
|
||||
9,
|
||||
69,
|
||||
27,
|
||||
442,
|
||||
416,
|
||||
41,
|
||||
1499,
|
||||
22,
|
||||
2771,
|
||||
24,
|
||||
11335,
|
||||
20,
|
||||
755,
|
||||
18,
|
||||
9225,
|
||||
2198,
|
||||
14285,
|
||||
9,
|
||||
69,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
1499,
|
||||
22,
|
||||
642,
|
||||
22,
|
||||
2771,
|
||||
]
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
||||
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
||||
@@ -981,8 +910,9 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
||||
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
||||
# <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
|
||||
# He is asked to perform a ritual of the Virgin Mary. He is asked to perform
|
||||
# <sep><cls>, Rasputin is asked to perform magic.
|
||||
# He is not able to perform magic, and his father and
|
||||
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user