Using Piquasso with JAX
Piquasso and JAX can be used together: pass a pq.JaxCalculator instance to the simulator through its calculator argument, as in the following example:
import numpy as np
import piquasso as pq
jax_calculator = pq.JaxCalculator()
simulator = pq.PureFockSimulator(
    d=2,
    config=pq.Config(cutoff=5, dtype=np.float32, normalize=False),
    calculator=jax_calculator,
)
with pq.Program() as program:
    pq.Q() | pq.Vacuum()
    pq.Q(0) | pq.Displacement(r=0.43)
    pq.Q(0, 1) | pq.Beamsplitter(theta=np.pi / 3)
state = simulator.execute(program).state
print(state.fock_probabilities)
[8.3118737e-01 3.8421631e-02 1.1526491e-01 8.8801968e-04 5.3281188e-03
7.9921819e-03 1.3682902e-05 1.2314616e-04 3.6943855e-04 3.6943876e-04
1.5812304e-07 1.8974771e-06 8.5386482e-06 1.7077307e-05 1.2807981e-05]
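These numbers can be sanity-checked without Piquasso. A displaced vacuum is a coherent state, and a beamsplitter maps a coherent state in one mode to a product of coherent states, here |r cos θ⟩|r sin θ⟩ up to phases (assuming the usual beamsplitter convention). The probabilities then have the closed form P(n0, n1) = e^(-r²) (r cos θ)^(2 n0) (r sin θ)^(2 n1) / (n0! n1!). The plain-Python sketch below evaluates this formula; the helper name coherent_bs_probabilities is hypothetical, and the basis ordering it uses (by total photon number n, then from (n, 0) down to (0, n)) is inferred from the printed output above rather than taken from Piquasso's documentation.

```python
import math


def coherent_bs_probabilities(r, theta, cutoff):
    """Closed-form Fock probabilities of a displaced vacuum (real amplitude r
    on mode 0) sent through a beamsplitter with angle theta.

    Assumes the output is the product state |r cos(theta)>|r sin(theta)>
    (phases dropped, they do not change probabilities) and keeps states with
    fewer than `cutoff` photons in total. The ordering (total photon number n,
    then (n, 0) down to (0, n)) is inferred from the printed output.
    """
    mean0 = (r * math.cos(theta)) ** 2  # mean photon number in mode 0
    mean1 = (r * math.sin(theta)) ** 2  # mean photon number in mode 1
    vacuum = math.exp(-r ** 2)          # e^(-r^2) = e^(-mean0) * e^(-mean1)
    probabilities = []
    for n in range(cutoff):
        for n0 in range(n, -1, -1):
            n1 = n - n0
            probabilities.append(
                vacuum * mean0 ** n0 * mean1 ** n1
                / (math.factorial(n0) * math.factorial(n1))
            )
    return probabilities


probs = coherent_bs_probabilities(0.43, math.pi / 3, cutoff=5)
print(len(probs))
print([f"{p:.7e}" for p in probs])
```

With r = 0.43, θ = π/3 and cutoff 5 this gives 15 entries, matching the length of the printed vector, and the values agree with it up to float32 rounding.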
Wrapping the calculation in a function and compiling it with jax.jit makes repeated evaluations much faster:
import time
import numpy as np
import piquasso as pq
from jax import jit
def func(r, theta):
    # Same simulation as above, with r and theta as function arguments.
    jax_calculator = pq.JaxCalculator()
    simulator = pq.PureFockSimulator(
        d=2,
        config=pq.Config(cutoff=5, dtype=np.float32, normalize=False),
        calculator=jax_calculator,
    )
    with pq.Program() as program:
        pq.Q() | pq.Vacuum()
        pq.Q(0) | pq.Displacement(r=r)
        pq.Q(0, 1) | pq.Beamsplitter(theta=theta)
    return simulator.execute(program).state.fock_probabilities
# Compile func with JAX's just-in-time compiler.
compiled_func = jit(func)
iterations = 10
for i in range(iterations):
    r = np.random.rand()
    theta = np.random.rand()
    # Time the uncompiled function.
    start_time = time.time()
    func(r, theta)
    print(f"{i+1}. original runtime:\t{time.time() - start_time} s")
    # Time the compiled function.
    start_time = time.time()
    compiled_func(r, theta)
    print(f"{i+1}. compiled runtime:\t{time.time() - start_time} s")
1. original runtime: 0.053116798400878906 s
1. compiled runtime: 0.48732662200927734 s
2. original runtime: 0.04149889945983887 s
2. compiled runtime: 3.361701965332031e-05 s
3. original runtime: 0.03750133514404297 s
3. compiled runtime: 3.218650817871094e-05 s
4. original runtime: 0.03447985649108887 s
4. compiled runtime: 3.170967102050781e-05 s
5. original runtime: 0.033095598220825195 s
5. compiled runtime: 3.0994415283203125e-05 s
6. original runtime: 0.031104087829589844 s
6. compiled runtime: 3.123283386230469e-05 s
7. original runtime: 0.030498027801513672 s
7. compiled runtime: 3.504753112792969e-05 s
8. original runtime: 0.0323636531829834 s
8. compiled runtime: 3.123283386230469e-05 s
9. original runtime: 0.034629106521606445 s
9. compiled runtime: 3.218650817871094e-05 s
10. original runtime: 0.03313851356506348 s
10. compiled runtime: 3.218650817871094e-05 s
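One caveat about measurements like these: JAX dispatches work asynchronously, so a call can return as soon as the computation has been queued, and timing it without waiting for the result may understate how long the computation really takes. The JAX documentation recommends calling block_until_ready() on the result before stopping the clock. The sketch below shows that pattern on a small stand-in function; vacuum_probability and the timed helper are hypothetical names for illustration and are not part of Piquasso.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def vacuum_probability(r):
    # Hypothetical stand-in for a compiled simulation: the vacuum
    # probability exp(-r^2) of a coherent state with real amplitude r.
    return jnp.exp(-r ** 2)


def timed(fn, *args):
    """Call fn(*args) and return (result, elapsed seconds), waiting for
    JAX's asynchronous dispatch to finish before stopping the clock."""
    start = time.perf_counter()
    result = fn(*args).block_until_ready()
    return result, time.perf_counter() - start


# The first call includes tracing and compilation.
p_first, first_seconds = timed(vacuum_probability, 0.43)
# Later calls with the same argument shape and dtype reuse the compiled code.
p_second, second_seconds = timed(vacuum_probability, 0.5)
print(f"first call:  {first_seconds:.2e} s")
print(f"second call: {second_seconds:.2e} s")
```

With the function from the cell above, the same helper would be called as timed(compiled_func, r, theta); the output of a jax.jit-compiled function is a JAX array, so block_until_ready() is available on it.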
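Notice that the first call of the compiled version is slower than the original one (about 0.49 s versus 0.05 s), because JAX traces and compiles the function on its first call. Subsequent calls reuse the compiled code and are significantly faster.
One can also calculate the Jacobian of this function, e.g., with jax.jacfwd: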
from jax import jacfwd
# Forward-mode Jacobian with respect to both r (argument 0) and theta (argument 1).
jacobian_func = jit(jacfwd(compiled_func, argnums=(0, 1)))
jacobian = jacobian_func(0.43, np.pi / 3)
print("Jacobian by 'r': ", jacobian[0])
print("Jacobian by 'theta': ", jacobian[1])
Jacobian by 'r': [-7.1482110e-01 1.4566265e-01 4.3698803e-01 7.4969521e-03
4.4981722e-02 6.7472607e-02 1.7915692e-04 1.6124130e-03
4.8372396e-03 4.8372415e-03 2.8058382e-06 3.3670065e-05
1.5151533e-04 3.0303077e-04 2.2727314e-04]
Jacobian by 'theta': [ 0.0000000e+00 -1.3309644e-01 1.3309643e-01 -6.1523812e-03
-1.2304768e-02 1.8457148e-02 -1.4219691e-04 -7.1098475e-04
-4.2659120e-04 1.2797731e-03 -2.1910173e-06 -1.7528142e-05
-3.9438339e-05 -1.7730152e-11 5.9157519e-05]
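The first entries of these Jacobians can be cross-checked by hand. Assuming the output state is the product of coherent states |r cos θ⟩|r sin θ⟩ (up to phases), P(n0, n1) = e^(-r²) (r cos θ)^(2 n0) (r sin θ)^(2 n1) / (n0! n1!), which gives ∂P/∂r = P (2(n0 + n1)/r - 2r) and ∂P/∂θ = P (2 n1 cot θ - 2 n0 tan θ). The sketch below applies jax.jacfwd to this closed form for |0,0⟩, |1,0⟩ and |0,1⟩ (the first three basis states, in the order inferred from the printed output) and compares the result with the hand-derived formulas. It does not call Piquasso; selected_probabilities and analytic_jacobian are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def selected_probabilities(r, theta):
    """Closed-form probabilities of |0,0>, |1,0> and |0,1> for a displaced
    vacuum with real amplitude r sent through a beamsplitter with angle theta.

    Assumes the output state is |r cos(theta)>|r sin(theta)> up to phases.
    """
    vacuum = jnp.exp(-r ** 2)
    return jnp.stack([
        vacuum,
        vacuum * (r * jnp.cos(theta)) ** 2,
        vacuum * (r * jnp.sin(theta)) ** 2,
    ])


def analytic_jacobian(r, theta):
    """Hand-derived partial derivatives of selected_probabilities."""
    p = selected_probabilities(r, theta)
    n0 = jnp.array([0.0, 1.0, 0.0])  # photons in mode 0 for each entry
    n1 = jnp.array([0.0, 0.0, 1.0])  # photons in mode 1 for each entry
    d_dr = p * (2 * (n0 + n1) / r - 2 * r)
    d_dtheta = p * (2 * n1 / jnp.tan(theta) - 2 * n0 * jnp.tan(theta))
    return d_dr, d_dtheta


# Forward-mode Jacobian with respect to both arguments, compiled with jit.
jacobian_func = jax.jit(jax.jacfwd(selected_probabilities, argnums=(0, 1)))
auto_dr, auto_dtheta = jacobian_func(0.43, jnp.pi / 3)
hand_dr, hand_dtheta = analytic_jacobian(0.43, jnp.pi / 3)

print("d/dr:    ", auto_dr)
print("d/dtheta:", auto_dtheta)
print("matches hand-derived formulas:",
      bool(jnp.allclose(auto_dr, hand_dr, rtol=1e-4))
      and bool(jnp.allclose(auto_dtheta, hand_dtheta, rtol=1e-4)))
```

At r = 0.43 and θ = π/3 these values agree with the first three entries of each printed Jacobian above; for example, ∂P/∂r of the vacuum is -2r e^(-r²) ≈ -0.715, and its derivative with respect to θ is exactly zero.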