Refactor: Removed un-necessary object base class (#32230)
* Refactored to remove un-necessary object base class. * small fix.
This commit is contained in:
@@ -284,7 +284,7 @@ def make_fast_generalized_attention(
|
||||
return attention_fn
|
||||
|
||||
|
||||
class RandomMatrix(object):
|
||||
class RandomMatrix:
|
||||
r"""
|
||||
Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
|
||||
random arrays.
|
||||
@@ -348,7 +348,7 @@ class GaussianOrthogonalRandomMatrix(RandomMatrix):
|
||||
return jnp.matmul(jnp.diag(multiplier), final_matrix)
|
||||
|
||||
|
||||
class FastAttention(object):
|
||||
class FastAttention:
|
||||
r"""
|
||||
Abstract class providing a method for fast attention. Class is responsible for providing a method
|
||||
<dot_product_attention> for fast approximate attention.
|
||||
|
||||
Reference in New Issue
Block a user