Ikeda Map in JAX

Build classical discriminator

import jax
import jax.numpy as jnp
from jax import jit
from jax import lax

# Draw `size` starting points uniformly from the half-open square
# [0, 0.5) x [0, 0.5); each row of `points` is one (x, y) pair.
key = jax.random.key(0)
size = 1_000_000
points = jax.random.uniform(key, (size, 2), minval=0, maxval=0.5)

@jit
def ikeda_update(point):
    u = 0.9
    x, y = point[0], point[1]
    t = 0.4 - (6 / (1 + x ** 2 + y ** 2))
    x_new = 1 + u * (x * jnp.cos(t) - y * jnp.sin(t))
    y_new = u * (x * jnp.sin(t) + y * jnp.cos(t))
    return jnp.array((x_new, y_new))

ikeda_update = jit(jax.vmap(ikeda_update))

@jit
def iterate(points):
    def body_fn(i, val):
        return ikeda_update(val)
    return lax.fori_loop(0, 50, body_fn, points)

# Advance every sampled point through the Ikeda map via `iterate`.
points = iterate(points)
# Bare expression: as the last line of a notebook cell, this displays the array.
points
Array([[ 0.16112472, -0.16200231],
       [ 0.6199205 , -0.66500235],
       [ 0.0068292 , -0.62714374],
       ...,
       [-0.1013843 ,  0.36023578],
       [ 0.13512027, -1.1439236 ],
       [ 0.28192288, -0.15245877]], dtype=float32)
%%timeit
ikeda_update(points)
2.59 ms ± 42.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)