Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -37,6 +37,8 @@ if is_tf_available():
|
||||
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = () if is_tf_available() else ()
|
||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
@@ -62,6 +64,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
num_hidden_layers=5,
|
||||
scope=None,
|
||||
seed=1,
|
||||
eos_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -82,6 +85,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.scope = scope
|
||||
self.seed = seed
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -103,6 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
d_inner=self.d_inner,
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
||||
Reference in New Issue
Block a user