Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -105,8 +105,8 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
v = self.split_into_heads(v, batch_size)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = tf.unstack(layer_past, axis=1)
|
||||
k = tf.concat((past_key, k), dim=-2)
|
||||
v = tf.concat((past_value, v), dim=-2)
|
||||
k = tf.concat((past_key, k), axis=-2)
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
present = tf.stack((k, v), axis=1)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
@@ -505,6 +505,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.input_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def call(self, inputs, **kwargs):
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user