Ikeda Map in JAX
Build classical discriminator
# Iterate the Ikeda map on a cloud of random starting points: jax.vmap
# vectorizes the single-point update, and lax.fori_loop runs all iterations
# inside one compiled XLA computation.
import jax
import jax.numpy as jnp
from jax import jit
from jax import lax

key = jax.random.key(0)
size = 10**6
# `size` starting points drawn uniformly from [0, 0.5) in each coordinate.
points = jax.random.uniform(key, shape=(size, 2), minval=0, maxval=0.5)


def ikeda_step(point):
    """Apply one step of the Ikeda map to a single point.

    Args:
        point: array whose first two entries are (x, y).

    Returns:
        Array of shape (2,) holding the updated (x, y).
    """
    u = 0.9  # Ikeda map parameter
    x, y = point[0], point[1]
    t = 0.4 - 6 / (1 + x**2 + y**2)
    cos_t, sin_t = jnp.cos(t), jnp.sin(t)
    x_new = 1 + u * (x * cos_t - y * sin_t)
    y_new = u * (x * sin_t + y * cos_t)
    return jnp.array((x_new, y_new))


# Batched, compiled update: maps ikeda_step over the leading axis of an
# (N, 2) array of points. The per-point function is intentionally not jitted
# on its own; this single jit compiles the whole vectorized update, and the
# single-point version stays available as `ikeda_step`.
ikeda_update = jit(jax.vmap(ikeda_step))


def iterate(points, num_iters=50):
    """Advance every point by `num_iters` steps of the Ikeda map.

    Args:
        points: (N, 2) array of points.
        num_iters: number of map iterations (default 50). It is a static
            argument to jit, so each distinct value triggers a recompile and
            fori_loop sees a concrete trip count.

    Returns:
        (N, 2) array with every point advanced `num_iters` steps.
    """

    def body_fn(_, val):
        return ikeda_update(val)

    return lax.fori_loop(0, num_iters, body_fn, points)


iterate = jit(iterate, static_argnames=("num_iters",))

points = iterate(points)
# Bare expression so the notebook cell displays the result.
points
Array([[ 0.16112472, -0.16200231],
[ 0.6199205 , -0.66500235],
[ 0.0068292 , -0.62714374],
...,
[-0.1013843 , 0.36023578],
[ 0.13512027, -1.1439236 ],
[ 0.28192288, -0.15245877]], dtype=float32)
%%timeit ikeda_update(points)
2.59 ms ± 42.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)