
# Using Piquasso with JAX

Piquasso and JAX can be used together: pass a `pq.JaxCalculator` instance as the `calculator` argument of the simulator, as demonstrated by the following simple example:


```python
import numpy as np
import piquasso as pq

jax_calculator = pq.JaxCalculator()

simulator = pq.PureFockSimulator(
    d=2,
    config=pq.Config(cutoff=5, dtype=np.float32, normalize=False),
    calculator=jax_calculator,
)

with pq.Program() as program:
    pq.Q() | pq.Vacuum()

    pq.Q(0) | pq.Displacement(r=0.43)

    pq.Q(0, 1) | pq.Beamsplitter(theta=np.pi / 3)

state = simulator.execute(program).state

print(state.fock_probabilities)
```

```
[8.3118737e-01 3.8421631e-02 1.1526491e-01 8.8801968e-04 5.3281188e-03
 7.9921819e-03 1.3682902e-05 1.2314616e-04 3.6943855e-04 3.6943876e-04
 1.5812304e-07 1.8974771e-06 8.5386482e-06 1.7077307e-05 1.2807981e-05]
```
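These numbers can be sanity-checked by hand. The 15 entries match the number of two-mode Fock states with at most `cutoff - 1 = 4` photons (1 + 2 + 3 + 4 + 5). The displaced vacuum is a coherent state, and the output is consistent with a mean photon number of r² (i.e. assuming the displacement phase defaults to zero); a beamsplitter conserves the total photon number. Summing the probabilities over each shell of fixed total photon number n should therefore give the Poisson weights e^(-r²) r^(2n) / n!. The minimal sketch below uses only NumPy and the values printed above, and it assumes that the entries are ordered in shells of increasing total photon number (n + 1 states per shell for two modes):

```python
import math

import numpy as np

# Fock probabilities printed above (r = 0.43, theta = pi / 3, cutoff = 5).
probabilities = np.array([
    8.3118737e-01, 3.8421631e-02, 1.1526491e-01, 8.8801968e-04, 5.3281188e-03,
    7.9921819e-03, 1.3682902e-05, 1.2314616e-04, 3.6943855e-04, 3.6943876e-04,
    1.5812304e-07, 1.8974771e-06, 8.5386482e-06, 1.7077307e-05, 1.2807981e-05,
])

r = 0.43
cutoff = 5
mean_photon_number = r**2  # assumes alpha = r, i.e. zero displacement phase

# Assumed ordering: shells n = 0, 1, ..., cutoff - 1 with n + 1 states each.
shell_sizes = [n + 1 for n in range(cutoff)]
shell_starts = np.concatenate(([0], np.cumsum(shell_sizes)[:-1]))
shell_sums = np.add.reduceat(probabilities, shell_starts)

# Poisson photon-number distribution of a coherent state.
poisson = np.array([
    math.exp(-mean_photon_number) * mean_photon_number**n / math.factorial(n)
    for n in range(cutoff)
])

print(shell_sums)
print(poisson)
# Slightly below 1: the cutoff truncates the state and normalize=False.
print(probabilities.sum())
```

The total falls short of 1 by the probability of finding 5 or more photons, which the cutoff discards; with `normalize=False` that missing weight is not redistributed.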


One can improve the speed of this calculation by compiling it with `jax.jit`. To do so, the simulation is wrapped in a function of the parameters `r` and `theta`:


```python
import time

import numpy as np
import piquasso as pq

from jax import jit


def func(r, theta):
    # Build and run the same two-mode circuit, parametrized by r and theta.
    jax_calculator = pq.JaxCalculator()

    simulator = pq.PureFockSimulator(
        d=2,
        config=pq.Config(cutoff=5, dtype=np.float32, normalize=False),
        calculator=jax_calculator,
    )

    with pq.Program() as program:
        pq.Q() | pq.Vacuum()

        pq.Q(0) | pq.Displacement(r=r)

        pq.Q(0, 1) | pq.Beamsplitter(theta=theta)

    return simulator.execute(program).state.fock_probabilities


# Wrap `func` for JIT compilation; tracing and compilation happen on the first call.
compiled_func = jit(func)

iterations = 10

for i in range(iterations):
    # New random parameters with the same shape and dtype, so no recompilation.
    r = np.random.rand()
    theta = np.random.rand()

    start_time = time.time()
    func(r, theta)
    print(f"{i+1}. original runtime:\t{time.time() - start_time} s")

    start_time = time.time()
    compiled_func(r, theta)
    print(f"{i+1}. compiled runtime:\t{time.time() - start_time} s")
```

```
1. original runtime:    0.053116798400878906 s
1. compiled runtime:    0.48732662200927734 s
2. original runtime:    0.04149889945983887 s
2. compiled runtime:    3.361701965332031e-05 s
3. original runtime:    0.03750133514404297 s
3. compiled runtime:    3.218650817871094e-05 s
4. original runtime:    0.03447985649108887 s
4. compiled runtime:    3.170967102050781e-05 s
5. original runtime:    0.033095598220825195 s
5. compiled runtime:    3.0994415283203125e-05 s
6. original runtime:    0.031104087829589844 s
6. compiled runtime:    3.123283386230469e-05 s
7. original runtime:    0.030498027801513672 s
7. compiled runtime:    3.504753112792969e-05 s
8. original runtime:    0.0323636531829834 s
8. compiled runtime:    3.123283386230469e-05 s
9. original runtime:    0.034629106521606445 s
9. compiled runtime:    3.218650817871094e-05 s
10. original runtime:   0.03313851356506348 s
10. compiled runtime:   3.218650817871094e-05 s
```
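One caveat when reading such timings: JAX dispatches computations asynchronously, so a jitted call may return before its result has actually been computed, and timing the bare call can underestimate the real runtime. Calling `block_until_ready()` on the returned array waits for the computation to finish. A minimal sketch of this measurement pattern, using a stand-in pure-JAX function rather than Piquasso (`poisson_shells` is hypothetical and only evaluates the photon-number distribution e^(-r²) r^(2n) / n! of a coherent state for n = 0, ..., 4):

```python
import time

import jax
import jax.numpy as jnp


def poisson_shells(r):
    # Hypothetical stand-in, not part of Piquasso.
    n = jnp.arange(5)
    mean = r**2
    factorial = jnp.cumprod(jnp.maximum(n, 1))  # [1, 1, 2, 6, 24]
    return jnp.exp(-mean) * mean**n / factorial


compiled = jax.jit(poisson_shells)

# First call: includes tracing and XLA compilation.
start = time.perf_counter()
compiled(0.43).block_until_ready()
first_call = time.perf_counter() - start

# Later calls with the same input shape and dtype reuse the compiled code.
start = time.perf_counter()
result = compiled(0.43).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call:  {first_call:.2e} s")
print(f"second call: {second_call:.2e} s")
print(result)
```

The same pattern would apply to the Piquasso example, e.g. `compiled_func(r, theta).block_until_ready()`, assuming `fock_probabilities` is returned as a JAX array when `pq.JaxCalculator` is used.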


Notice that the first call of the compiled version is considerably slower, since it includes tracing and compiling `func`, but subsequent calls are roughly three orders of magnitude faster than the original function. One can also calculate the Jacobian of this function, e.g., with `jax.jacfwd`:


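```python
from jax import jacfwd

# Forward-mode Jacobian with respect to both `r` (argnums=0) and
# `theta` (argnums=1), itself compiled with `jit`.
jacobian_func = jit(jacfwd(compiled_func, argnums=(0, 1)))

jacobian = jacobian_func(0.43, np.pi / 3)

print("Jacobian by 'r': ", jacobian[0])
print("Jacobian by 'theta': ", jacobian[1])
```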

```
Jacobian by 'r':  [-7.1482110e-01  1.4566265e-01  4.3698803e-01  7.4969521e-03
  4.4981722e-02  6.7472607e-02  1.7915692e-04  1.6124130e-03
  4.8372396e-03  4.8372415e-03  2.8058382e-06  3.3670065e-05
  1.5151533e-04  3.0303077e-04  2.2727314e-04]
Jacobian by 'theta':  [ 0.0000000e+00 -1.3309644e-01  1.3309643e-01 -6.1523812e-03
 -1.2304768e-02  1.8457148e-02 -1.4219691e-04 -7.1098475e-04
 -4.2659120e-04  1.2797731e-03 -2.1910173e-06 -1.7528142e-05
 -3.9438339e-05 -1.7730152e-11  5.9157519e-05]
```
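The Jacobian can be sanity-checked in the same spirit. Assuming the displaced vacuum has mean photon number r² (which the probabilities above are consistent with), the probability of finding n photons in total is e^(-r²) r^(2n) / n!, whose derivative with respect to r is that weight times (2n/r - 2r). A beamsplitter conserves the total photon number, so within each shell of fixed n the `theta` derivatives should cancel. A minimal NumPy sketch using the values printed above, again assuming the entries are ordered in shells of increasing total photon number (n + 1 states per shell for two modes):

```python
import math

import numpy as np

# Jacobian entries printed above (r = 0.43, theta = pi / 3, cutoff = 5).
jac_r = np.array([
    -7.1482110e-01, 1.4566265e-01, 4.3698803e-01, 7.4969521e-03, 4.4981722e-02,
    6.7472607e-02, 1.7915692e-04, 1.6124130e-03, 4.8372396e-03, 4.8372415e-03,
    2.8058382e-06, 3.3670065e-05, 1.5151533e-04, 3.0303077e-04, 2.2727314e-04,
])
jac_theta = np.array([
    0.0000000e+00, -1.3309644e-01, 1.3309643e-01, -6.1523812e-03, -1.2304768e-02,
    1.8457148e-02, -1.4219691e-04, -7.1098475e-04, -4.2659120e-04, 1.2797731e-03,
    -2.1910173e-06, -1.7528142e-05, -3.9438339e-05, -1.7730152e-11, 5.9157519e-05,
])

r = 0.43
cutoff = 5

# Assumed ordering: shells n = 0, 1, ..., cutoff - 1 with n + 1 states each.
shell_starts = np.concatenate(([0], np.cumsum([n + 1 for n in range(cutoff)])[:-1]))
jac_r_shells = np.add.reduceat(jac_r, shell_starts)
jac_theta_shells = np.add.reduceat(jac_theta, shell_starts)

# d/dr of the Poisson weight exp(-r^2) * r^(2n) / n!
analytic_r = np.array([
    math.exp(-r**2) * r**(2 * n) / math.factorial(n) * (2 * n / r - 2 * r)
    for n in range(cutoff)
])

print(np.allclose(jac_r_shells, analytic_r, rtol=1e-4))
# The beamsplitter only redistributes photons within a shell.
print(np.allclose(jac_theta_shells, 0.0, atol=1e-6))
```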