Black 20 release
This commit is contained in:
@@ -32,7 +32,8 @@ if is_tf_available():
|
||||
|
||||
class TFT5ModelTester:
|
||||
def __init__(
|
||||
self, parent,
|
||||
self,
|
||||
parent,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
@@ -181,7 +182,10 @@ class TFT5ModelTester:
|
||||
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
attn_mask = tf.concat([attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], axis=1,)
|
||||
attn_mask = tf.concat(
|
||||
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
|
||||
Reference in New Issue
Block a user