[XLNet] Fix mems behavior (#8567)

* fix mems in xlnet

* fix use_mems

* fix use_mem_len

* fix use mems

* clean docs

* fix tf typo

* make xlnet tf for generation work

* fix tf test

* refactor use cache

* add use cache for missing models

* correct use_cache in generate

* correct use cache in tf generate

* fix tf

* correct getattr typo

* make sylvain happy

* change in docs as well

* do not apply to cookie cutter statements

* fix tf test

* make pytorch model fully backward compatible
This commit is contained in:
Patrick von Platen
2020-11-25 22:54:59 +01:00
committed by GitHub
parent 369f1d77b4
commit 2a6fbe6a40
47 changed files with 259 additions and 134 deletions

View File

@@ -206,7 +206,36 @@ class XLNetModelTester:
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_model_use_cache(
def create_and_check_use_mems_train(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetForSequenceClassification(config)
model.to(torch_device)
model.train()
train_size = input_ids_1.shape[0]
batch_size = 4
for i in range(train_size // batch_size + 1):
input_ids = input_ids_1[i : (i + 1) * batch_size]
labels = sequence_labels[i : (i + 1) * batch_size]
outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
self.parent.assertIsNone(outputs.mems)
self.parent.assertIsNotNone(outputs.loss)
def create_and_check_xlnet_model_use_mems(
self,
config,
input_ids_1,
@@ -234,8 +263,8 @@ class XLNetModelTester:
device=torch_device,
)
causal_mask = torch.triu(causal_mask, diagonal=0)
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
outputs_conf = model(input_ids_1)
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
@@ -525,11 +554,15 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_base_model_use_cache(self):
# checking that in auto-regressive mode, :obj:`use_cache` gives the same results
def test_xlnet_base_model_use_mems(self):
# checking that in auto-regressive mode, :obj:`use_mems` gives the same results
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)
def test_seq_classification_use_mems_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_use_mems_train(*config_and_inputs)
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()