* Fix missing import jnp
* Fix missing jax and k=1

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>