Add generate() functionality to TF 2.0 (#3063)

* add first copy past test to tf 2 generate

* add tf top_k_top_p_filter fn

* add generate function for TF

* add generate function for TF

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* make style

* change permission of test file to correct ones

* delete ipdb

* delete ipdb

* fix bug and finish simple gpt2 integration test

* clean test file

* clean test file

* make style

* make style

* make style

* make style

* change import style

* change import style

* make style

* make style

* add decorators

* add decorators

* fix tf ctrl bug dim => axis in TF

* make style

* make style

* refactored test file

* refactored test file

* take out test_torch_tf_conversion if nothing is defined

* take out test_torch_tf_conversion if nothing is defined

* remove useless files

* remove useless files

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* solve conflicts

* solve conflicts

* fix conflicts

* fix conflicts

* merge conflicts

* delete ipdb

* exposed top_k_top_p_filtering fns

* delete weirdly created w! file

* add comment to test tf common modeling

* fix conflicts

* fix conflicts

* make style

* merge conflicts

* make style

* change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
Patrick von Platen
2020-03-03 15:42:15 +01:00
committed by GitHub
parent b31f715019
commit 4134100363
20 changed files with 892 additions and 62 deletions

View File

@@ -105,8 +105,8 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=1)
k = tf.concat((past_key, k), dim=-2)
v = tf.concat((past_value, v), dim=-2)
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=1)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
@@ -505,6 +505,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""