Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -36,6 +36,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
BertConfig,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
||||
@@ -263,7 +264,7 @@ class ModelTesterMixin:
|
||||
# Prepare head_mask
|
||||
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
|
||||
head_mask = torch.ones(
|
||||
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device
|
||||
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
|
||||
)
|
||||
head_mask[0, 0] = 0
|
||||
head_mask[-1, :-1] = 0
|
||||
@@ -303,7 +304,7 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -313,7 +314,10 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
}
|
||||
model.prune_heads(heads_to_prune)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
@@ -329,7 +333,7 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -339,7 +343,10 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
}
|
||||
model.prune_heads(heads_to_prune)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
@@ -359,7 +366,7 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -367,7 +374,10 @@ class ModelTesterMixin:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
}
|
||||
config.pruned_heads = heads_to_prune
|
||||
|
||||
model = model_class(config=config)
|
||||
@@ -387,7 +397,7 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -465,7 +475,7 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
|
||||
@@ -634,6 +644,7 @@ class ModelTesterMixin:
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
@@ -704,3 +715,110 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276, # 5th highest value; idx. 9
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cummulative prob of 5 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958, # 5th highest value; idx. 18
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cummulative prob of 5 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
7.3534,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
5.4403,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
||||
@@ -386,33 +386,33 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
||||
input_ids = torch.Tensor([[464, 1893]]).long() # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
318,
|
||||
13779,
|
||||
996,
|
||||
339,
|
||||
460,
|
||||
3360,
|
||||
655,
|
||||
2513,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
1829,
|
||||
11,
|
||||
290,
|
||||
262,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
7526,
|
||||
11,
|
||||
423,
|
||||
587,
|
||||
287,
|
||||
262,
|
||||
3952,
|
||||
13,
|
||||
632,
|
||||
318,
|
||||
407,
|
||||
845,
|
||||
3621,
|
||||
284,
|
||||
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
|
||||
torch.manual_seed(0)
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
bos_token_id=self.special_tokens["bos_token_id"],
|
||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ import copy
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
|
||||
@@ -28,6 +29,8 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from transformers import tf_top_k_top_p_filtering
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
@@ -56,6 +59,7 @@ class TFModelTesterMixin:
|
||||
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
test_torchscript = True
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
@@ -216,7 +220,7 @@ class TFModelTesterMixin:
|
||||
outputs_dict = model(inputs_dict)
|
||||
|
||||
inputs_keywords = copy.deepcopy(inputs_dict)
|
||||
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None)
|
||||
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None,)
|
||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
@@ -299,7 +303,7 @@ class TFModelTesterMixin:
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size]
|
||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
@@ -316,7 +320,10 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
||||
first, second = (
|
||||
model(inputs_dict, training=False)[0],
|
||||
model(inputs_dict, training=False)[0],
|
||||
)
|
||||
out_1 = first.numpy()
|
||||
out_2 = second.numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
@@ -338,9 +345,9 @@ class TFModelTesterMixin:
|
||||
x = wte([input_ids, None, None, None], mode="embedding")
|
||||
except Exception:
|
||||
if hasattr(self.model_tester, "embedding_size"):
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32,)
|
||||
else:
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32,)
|
||||
return x
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
@@ -366,6 +373,37 @@ class TFModelTesterMixin:
|
||||
|
||||
model(inputs_dict)
|
||||
|
||||
def test_lm_head_model_random_generate(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get(
|
||||
"input_ids", None
|
||||
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# TODO (PVP): add beam search tests when beam search is implemented
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
|
||||
|
||||
def _check_generated_tokens(self, output_ids):
|
||||
for token_id in output_ids[0].numpy().tolist():
|
||||
self.assertGreaterEqual(token_id, 0)
|
||||
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
@@ -383,3 +421,98 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
output = tf.constant(values, shape=shape, dtype=dtype if dtype is not None else tf.int32)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@require_tf
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p_filtering function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = tf.convert_to_tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276, # 5th highest value; idx. 9
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cummulative prob of 5 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958, # 5th highest value; idx. 18
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cummulative prob of 5 highest values <= 0.6
|
||||
],
|
||||
dtype=tf.float32,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = tf.convert_to_tensor(
|
||||
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], dtype=tf.int32,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = tf.convert_to_tensor(
|
||||
[8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
|
||||
dtype=tf.float32,
|
||||
) # expected non filtered values as noted above
|
||||
|
||||
output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
|
||||
non_inf_output = output[output != -float("inf")]
|
||||
non_inf_idx = tf.cast(
|
||||
tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))), dtype=tf.int32,
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
|
||||
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
|
||||
|
||||
@@ -31,6 +31,7 @@ if is_tf_available():
|
||||
class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||
|
||||
class TFCTRLModelTester(object):
|
||||
def __init__(
|
||||
|
||||
@@ -37,7 +37,7 @@ if is_tf_available():
|
||||
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
|
||||
class TFGPT2ModelTester(object):
|
||||
def __init__(
|
||||
@@ -89,6 +89,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -123,9 +125,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@@ -144,7 +148,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2Model(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
sequence_output = model(inputs)[0]
|
||||
|
||||
inputs = [input_ids, None, input_mask] # None is the input for 'past'
|
||||
@@ -156,18 +164,22 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"sequence_output": sequence_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2LMHeadModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
prediction_scores = model(inputs)[0]
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_gpt2_double_head(
|
||||
@@ -188,7 +200,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
lm_logits, mc_logits = model(inputs)[:2]
|
||||
result = {"lm_logits": lm_logits.numpy(), "mc_logits": mc_logits.numpy()}
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
@@ -207,7 +219,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": input_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -234,3 +250,48 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
def prepare_generation_special_tokens():
|
||||
return {"bos_token_id": 50256, "eos_token_id": 50256}
|
||||
|
||||
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
special_tokens = prepare_generation_special_tokens()
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
1829,
|
||||
11,
|
||||
290,
|
||||
262,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
7526,
|
||||
11,
|
||||
423,
|
||||
587,
|
||||
287,
|
||||
262,
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
bos_token_id=self.special_tokens["bos_token_id"],
|
||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||
)
|
||||
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -39,6 +39,9 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (
|
||||
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||
|
||||
class TFOpenAIGPTModelTester(object):
|
||||
def __init__(
|
||||
|
||||
@@ -37,6 +37,8 @@ if is_tf_available():
|
||||
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = () if is_tf_available() else ()
|
||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
@@ -62,6 +64,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
num_hidden_layers=5,
|
||||
scope=None,
|
||||
seed=1,
|
||||
eos_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -82,6 +85,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.scope = scope
|
||||
self.seed = seed
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -103,6 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
d_inner=self.d_inner,
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
||||
@@ -43,6 +43,9 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (
|
||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
|
||||
class TFXLMModelTester(object):
|
||||
def __init__(
|
||||
@@ -75,6 +78,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
summary_type="last",
|
||||
use_proj=True,
|
||||
scope=None,
|
||||
bos_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -105,6 +109,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -145,6 +150,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
initializer_range=self.initializer_range,
|
||||
summary_type=self.summary_type,
|
||||
use_proj=self.use_proj,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -51,6 +51,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (
|
||||
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_pruning = False
|
||||
|
||||
class TFXLNetModelTester(object):
|
||||
@@ -77,6 +80,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
initializer_range=0.05,
|
||||
seed=1,
|
||||
type_vocab_size=2,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pad_token_id=5,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -100,6 +106,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.seed = seed
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -139,6 +148,9 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
bi_data=self.bi_data,
|
||||
initializer_range=self.initializer_range,
|
||||
num_labels=self.type_sequence_label_size,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user