Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -136,7 +136,7 @@ if is_sklearn_available():
|
||||
|
||||
# Modeling
|
||||
if is_torch_available():
|
||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering
|
||||
from .modeling_auto import (
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
@@ -291,7 +291,13 @@ if is_torch_available():
|
||||
|
||||
# TensorFlow
|
||||
if is_tf_available():
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
||||
from .modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFSequenceSummary,
|
||||
shape_list,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
from .modeling_tf_auto import (
|
||||
TFAutoModel,
|
||||
TFAutoModelForPreTraining,
|
||||
|
||||
Reference in New Issue
Block a user