[XLNet] Fix mems behavior (#8567)
* fix mems in xlnet * fix use_mems * fix use_mem_len * fix use mems * clean docs * fix tf typo * make xlnet tf for generation work * fix tf test * refactor use cache * add use cache for missing models * correct use_cache in generate * correct use cache in tf generate * fix tf * correct getattr typo * make sylvain happy * change in docs as well * do not apply to cookie cutter statements * fix tf test * make pytorch model fully backward compatible
This commit is contained in:
committed by
GitHub
parent
369f1d77b4
commit
2a6fbe6a40
@@ -206,7 +206,36 @@ class XLNetModelTester:
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_model_use_cache(
|
||||
def create_and_check_use_mems_train(
|
||||
self,
|
||||
config,
|
||||
input_ids_1,
|
||||
input_ids_2,
|
||||
input_ids_q,
|
||||
perm_mask,
|
||||
input_mask,
|
||||
target_mapping,
|
||||
segment_ids,
|
||||
lm_labels,
|
||||
sequence_labels,
|
||||
is_impossible_labels,
|
||||
token_labels,
|
||||
):
|
||||
model = XLNetForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
train_size = input_ids_1.shape[0]
|
||||
|
||||
batch_size = 4
|
||||
for i in range(train_size // batch_size + 1):
|
||||
input_ids = input_ids_1[i : (i + 1) * batch_size]
|
||||
labels = sequence_labels[i : (i + 1) * batch_size]
|
||||
outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
|
||||
self.parent.assertIsNone(outputs.mems)
|
||||
self.parent.assertIsNotNone(outputs.loss)
|
||||
|
||||
def create_and_check_xlnet_model_use_mems(
|
||||
self,
|
||||
config,
|
||||
input_ids_1,
|
||||
@@ -234,8 +263,8 @@ class XLNetModelTester:
|
||||
device=torch_device,
|
||||
)
|
||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
||||
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
|
||||
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
|
||||
outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
|
||||
outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
|
||||
outputs_conf = model(input_ids_1)
|
||||
|
||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
|
||||
@@ -525,11 +554,15 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
||||
|
||||
def test_xlnet_base_model_use_cache(self):
|
||||
# checking that in auto-regressive mode, :obj:`use_cache` gives the same results
|
||||
def test_xlnet_base_model_use_mems(self):
|
||||
# checking that in auto-regressive mode, :obj:`use_mems` gives the same results
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
|
||||
self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)
|
||||
|
||||
def test_seq_classification_use_mems_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_use_mems_train(*config_and_inputs)
|
||||
|
||||
def test_xlnet_base_model_with_att_output(self):
|
||||
self.model_tester.set_seed()
|
||||
|
||||
Reference in New Issue
Block a user