Add generate() functionality to TF 2.0 (#3063)

* add first copy past test to tf 2 generate

* add tf top_k_top_p_filter fn

* add generate function for TF

* add generate function for TF

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* make style

* change permission of test file to correct ones

* delete ipdb

* delete ipdb

* fix bug and finish simple gpt2 integration test

* clean test file

* clean test file

* make style

* make style

* make style

* make style

* change import style

* change import style

* make style

* make style

* add decorators

* add decorators

* fix tf ctrl bug dim => axis in TF

* make style

* make style

* refactored test file

* refactored test file

* take out test_torch_tf_conversion if nothing is defined

* take out test_torch_tf_conversion if nothing is defined

* remove useless files

* remove useless files

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* solve conflicts

* solve conflicts

* fix conflicts

* fix conflicts

* merge conflicts

* delete ipdb

* exposed top_k_top_p_filtering fns

* delete weirdly created w! file

* add comment to test tf common modeling

* fix conflicts

* fix conflicts

* make style

* merge conflicts

* make style

* change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
Patrick von Platen
2020-03-03 15:42:15 +01:00
committed by GitHub
parent b31f715019
commit 4134100363
20 changed files with 892 additions and 62 deletions

View File

@@ -36,6 +36,7 @@ if is_torch_available():
BertModel,
BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
top_k_top_p_filtering,
)
@@ -263,7 +264,7 @@ class ModelTesterMixin:
# Prepare head_mask
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask = torch.ones(
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
@@ -303,7 +304,7 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
@@ -313,7 +314,10 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
heads_to_prune = {
0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0],
}
model.prune_heads(heads_to_prune)
with torch.no_grad():
outputs = model(**inputs_dict)
@@ -329,7 +333,7 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
@@ -339,7 +343,10 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
heads_to_prune = {
0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0],
}
model.prune_heads(heads_to_prune)
with tempfile.TemporaryDirectory() as temp_dir_name:
@@ -359,7 +366,7 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
@@ -367,7 +374,10 @@ class ModelTesterMixin:
config.output_attentions = True
config.output_hidden_states = False
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
heads_to_prune = {
0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0],
}
config.pruned_heads = heads_to_prune
model = model_class(config=config)
@@ -387,7 +397,7 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
@@ -465,7 +475,7 @@ class ModelTesterMixin:
)
def test_resize_tokens_embeddings(self):
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
return
@@ -634,6 +644,7 @@ class ModelTesterMixin:
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
# batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy
@@ -704,3 +715,110 @@ class ModelUtilsTest(unittest.TestCase):
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[
8.2220991, # 3rd highest value; idx. 0
-0.5620044,
5.23229752,
4.0386393,
-6.8798378,
-0.54785802,
-3.2012153,
2.92777176,
1.88171953,
7.35341276, # 5th highest value; idx. 9
8.43207833, # 2nd highest value; idx. 10
-9.85711836,
-5.96209236,
-1.13039161,
-7.1115294,
-0.8369633,
-5.3186408,
7.06427407,
0.81369344,
-0.82023817,
-5.9179796,
0.58813443,
-6.99778438,
4.71551189,
-0.18771637,
7.44020759, # 4th highest value; idx. 25
9.38450987, # 1st highest value; idx. 26
2.12662941,
-9.32562038,
2.35652522,
], # cummulative prob of 5 highest values <= 0.6
[
0.58425518,
4.53139238,
-5.57510464,
-6.28030699,
-7.19529503,
-4.02122551,
1.39337037,
-6.06707057,
1.59480517,
-9.643119,
0.03907799,
0.67231762,
-8.88206726,
6.27115922, # 4th highest value; idx. 13
2.28520723,
4.82767506,
4.30421368,
8.8275313, # 2nd highest value; idx. 17
5.44029958, # 5th highest value; idx. 18
-4.4735794,
7.38579536, # 3rd highest value; idx. 20
-2.91051663,
2.61946077,
-2.5674762,
-9.48959302,
-4.02922645,
-1.35416918,
9.67702323, # 1st highest value; idx. 27
-5.89478553,
1.85370467,
], # cummulative prob of 5 highest values <= 0.6
],
dtype=torch.float,
device=torch_device,
)
non_inf_expected_idx = torch.tensor(
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
dtype=torch.long,
device=torch_device,
) # expected non filtered idx as noted above
non_inf_expected_output = torch.tensor(
[
8.2221,
7.3534,
8.4321,
7.4402,
9.3845,
6.2712,
8.8275,
5.4403,
7.3858,
9.6770,
], # expected non filtered values as noted above
dtype=torch.float,
device=torch_device,
)
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
non_inf_output = output[output != -float("inf")].to(device=torch_device)
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))