GPU acceleration¶
This notebook demonstrates GPU acceleration of superscreen models using the JAX package from Google.
Below we will look at the time required to solve both single-layer and multi-layer models using three different methods (the corresponding call patterns are sketched just after this list):

- Numpy (CPU): This is the default behavior when you call superscreen.solve() or superscreen.solve(..., gpu=False).
- JAX (CPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed but does not have a GPU available. We test this method by forcing JAX to run on the CPU, even if a GPU is available, using the context manager jax.default_device(jax.devices("cpu")[0]).
- JAX (GPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed and has a GPU available. The notebook simply skips this method if no GPU is available.
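For concreteness, the three call patterns look like the following sketch, where device stands for any superscreen.Device that has already been created and meshed (the solve calls are commented out because no device is defined here):

import jax
import superscreen as sc

# 1. Numpy (CPU), the default:
#    solutions = sc.solve(device)  # same as sc.solve(device, gpu=False)

# 2. JAX (CPU): force JAX onto the CPU, even if a GPU is present.
#    with jax.default_device(jax.devices("cpu")[0]):
#        solutions = sc.solve(device, gpu=True)

# 3. JAX (GPU): used automatically when JAX can see a GPU.
#    solutions = sc.solve(device, gpu=True)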
Notes¶
There are several things to keep in mind in order to get the most out of superscreen on a GPU:

- Currently, only computation performed inside superscreen.solve() can be offloaded to a GPU. Post-processing methods belonging to superscreen.Solution cannot (currently) be run on a GPU. Calculations that make heavy use of these post-processing methods (such as fluxoid optimization) can be expected to show only modest speed-ups from GPU acceleration.
- There is some overhead associated with moving data onto the GPU for processing, so calls to superscreen.solve() that take only tens of milliseconds on the CPU (e.g. single-layer models with only 1,000 or so mesh vertices) may actually be slightly slower on the GPU.
- On the other hand, most data only needs to be transferred to the GPU once per call to superscreen.solve(), so large gains can be seen when iteratively solving models involving multi-layer devices.

In short, you can expect the biggest speed-ups for models with many mesh vertices and/or more than one layer.
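If you want to request GPU acceleration only when JAX can actually see a GPU, one simple pattern (a sketch using the same device-kind check as the benchmark function below, not a dedicated superscreen API) is:

import jax
import superscreen as sc

# jax.devices() lists the devices of the default backend; on a CPU-only
# machine their device_kind contains "cpu".
use_gpu = "cpu" not in jax.devices()[0].device_kind

# `device` is assumed to be a superscreen.Device that has already been meshed:
# solutions = sc.solve(device, gpu=use_gpu)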
[1]:
%config InlineBackend.figure_formats = {"retina", "png"}
%matplotlib inline
import logging
from collections import defaultdict
import jax
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4.5)
plt.rcParams["font.size"] = 12
import superscreen as sc
from superscreen.geometry import circle
[2]:
sc.version_table()
[2]:
Software | Version |
---|---|
SuperScreen | 0.8.0 |
Numpy | 1.23.1 |
SciPy | 1.8.1 |
matplotlib | 3.5.2 |
ray | 1.13.0 |
jax | 0.3.15 |
IPython | 8.4.0 |
Python | 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] |
OS | posix [linux] |
Number of CPUs | Physical: 2, Logical: 2 |
BLAS Info | OPENBLAS |
Wed Dec 14 15:57:21 2022 PST |
[3]:
def run_benchmark(device, min_points_values, solve_kwargs):
    # Suppress warnings about not having a GPU.
    logging.getLogger("superscreen.solve").setLevel(logging.CRITICAL)
    mesh_sizes = []
    results = defaultdict(list)
    for min_points in min_points_values:
        device.make_mesh(min_points=min_points, smooth=50)
        mesh_sizes.append(device.points.shape[0])
        print(f"Mesh size: {device.points.shape[0]}")

        key = "Numpy (CPU)"
        print(f" {key}: ", end="")
        timing = %timeit -o sc.solve(device, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (CPU)"
        print(f" {key}: ", end="")
        with jax.default_device(jax.devices("cpu")[0]):
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (GPU)"
        print(f" {key}: ", end="")
        if "cpu" in jax.devices()[0].device_kind:
            print("Skipping because there is no GPU available.")
        else:
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
            results[key].append(timing)
    logging.getLogger("superscreen.solve").setLevel(logging.WARNING)
    return np.array(mesh_sizes), dict(results)
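run_benchmark returns the array of actual mesh sizes together with a dict mapping each method name to the list of TimeitResult objects produced by %timeit -o; their .average and .stdev attributes are what the plotting helper below uses.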
[4]:
def plot_benchmark(mesh_sizes, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
    ys_ref = np.array([t.average for t in results["Numpy (CPU)"]])
    for label, timing in results.items():
        xs = mesh_sizes
        ys = np.array([t.average for t in timing])
        yerr = np.array([t.stdev for t in timing])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Mesh size")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes
Single-layer device¶
Here we model a single-layer superconducting ring with some circulating current.
[5]:
length_units = "um"
ro = 3 # outer radius
ri = 1 # inner radius
slit_width = 0.25
layer = sc.Layer("base", london_lambda=0.100, thickness=0.025, z0=0)
ring = circle(ro)
hole = circle(ri)
bounding_box = sc.Polygon("bounding_box", layer="base", points=circle(1.2 * ro))
device = sc.Device(
"ring",
layers=[sc.Layer("base", london_lambda=0.100, thickness=0.025, z0=0)],
films=[sc.Polygon("ring", layer="base", points=ring)],
holes=[sc.Polygon("hole", layer="base", points=hole)],
abstract_regions=[bounding_box],
length_units=length_units,
)
device.solve_dtype = "float32"
fig, ax = device.draw(exclude="bounding_box", legend=True)

We get the same answer for the mutual inductance using either Numpy on the CPU or JAX.
[6]:
device.make_mesh(min_points=5000, smooth=50)
[7]:
M_numpy = device.mutual_inductance_matrix(units="pH", gpu=False)
display(M_numpy)
Magnitude | [[5.769204365742559]] |
---|---|
Units | picohenry |
[8]:
M_jax = device.mutual_inductance_matrix(units="pH", gpu=True)
display(M_jax)
Magnitude | [[5.76919768357049]] |
---|---|
Units | picohenry |
Measure the wall time of superscreen.solve()
vs. the number of points in the mesh.
[9]:
# Model the device with 1 mA circulating current
circulating_currents = {"hole": "1 mA"}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
)
# Look at solve time vs. mesh size
min_points = 1000 * np.arange(1, 11, dtype=int)
[10]:
mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 1015
Numpy (CPU): 27.6 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
JAX (CPU): 53 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 41 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 2105
Numpy (CPU): 126 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
JAX (CPU): 198 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 122 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3105
Numpy (CPU): 266 ms ± 485 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 411 ms ± 7.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 201 ms ± 2.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4285
Numpy (CPU): 549 ms ± 7.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 816 ms ± 13.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 331 ms ± 2.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5248
Numpy (CPU): 899 ms ± 8.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 1.33 s ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 509 ms ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6390
Numpy (CPU): 1.46 s ± 75.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 2.2 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 668 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7051
Numpy (CPU): 1.91 s ± 85.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 2.77 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 854 ms ± 57.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8748
Numpy (CPU): 3.04 s ± 89.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 4.34 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.27 s ± 62.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9711
Numpy (CPU): 3.97 s ± 98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 5.37 s ± 171 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.46 s ± 70.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 10744
Numpy (CPU): 4.92 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 6.22 s ± 85.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.59 s ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[11]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle("Single layer device")

Multi-layer device¶
Here we model a scanning Superconducting QUantum Interference Device (SQUID) susceptometer, which consists of three superconducting layers. In particular, we are interested in the mutual inductance between the field coil (green loop below) and pickup loop (orange loop below), which lie in different superconducting layers (called BE and W1, respectively).
This is a convenient benchmark because the mutual inductance for these devices has been thoroughly characterized experimentally and was found to be in the range \(166 \pm 4\,\Phi_0/\mathrm{A}\) (see Table 1 of arXiv:1605.09483 or DOI:10.1063/1.4961982).
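Concretely, the mutual inductance is estimated as \(M = \Phi_\mathrm{pl}/I_\mathrm{fc}\), where \(\Phi_\mathrm{pl}\) is the fluxoid induced in the pickup loop and \(I_\mathrm{fc}\) is the current circulating in the field coil; this is what the cells below compute from superscreen.Solution.hole_fluxoid().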
[12]:
import squids
[13]:
device = squids.ibm.medium.make_squid(align_layers="bottom")
device.solve_dtype = "float32"
[14]:
_ = device.draw(exclude="bounding_box", legend=True)

[15]:
# Model the device with 1 mA circulating current
I_circ = device.ureg("1 mA")
circulating_currents = {"fc_center": str(I_circ)}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
    iterations=5,
)
We get the same answer for the mutual inductance using either Numpy on the CPU or JAX.
[16]:
device.make_mesh(min_points=6000, smooth=50)
[17]:
solution_numpy = sc.solve(device, **solve_kwargs)[-1]
mutual_numpy = (sum(solution_numpy.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_numpy = {mutual_numpy:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_numpy = 162.92121 Φ_0/A
[18]:
solution_jax = sc.solve(device, gpu=True, **solve_kwargs)[-1]
mutual_jax = (sum(solution_jax.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_jax = {mutual_jax:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_jax = 162.92123 Φ_0/A
[19]:
fig, axes = solution_jax.plot_fields(figsize=(10, 2.75))
fig.tight_layout()

Measure the wall time of superscreen.solve()
vs. the number of points in the mesh.
[20]:
min_points = 1000 * np.arange(2, 11, dtype=int)
mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 2280
Numpy (CPU): 1.72 s ± 4.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 1.98 s ± 7.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.03 s ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3032
Numpy (CPU): 3.15 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 3.48 s ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.21 s ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4290
Numpy (CPU): 6.02 s ± 27.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 6.03 s ± 30.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.6 s ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5399
Numpy (CPU): 9.39 s ± 33.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 8.77 s ± 28.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.94 s ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6374
Numpy (CPU): 12.9 s ± 37.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 11.7 s ± 46.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.29 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7568
Numpy (CPU): 18 s ± 36.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 16 s ± 34.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.72 s ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8379
Numpy (CPU): 21.9 s ± 33.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 19.5 s ± 38.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.12 s ± 17.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9169
Numpy (CPU): 26.4 s ± 79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 23.4 s ± 68.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.44 s ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 10112
Numpy (CPU): 32 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 28 s ± 63.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.85 s ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[21]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{solve_kwargs['iterations']} iterations"
)

Here we measure the solve time vs. the number of iterations for a multi-layer device with a fixed number of mesh vertices.
[22]:
num_iterations = np.arange(11, dtype=int)
min_points = 6000

results = {}
for iterations in num_iterations:
    print(f"Number of iterations: {iterations}")
    solve_kwargs["iterations"] = iterations
    mesh_size, timing = run_benchmark(device, [min_points], solve_kwargs)
    results[iterations] = timing
    print()
Number of iterations: 0
Mesh size: 6374
Numpy (CPU): 1.11 s ± 8.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 1.79 s ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 798 ms ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 1
Mesh size: 6374
Numpy (CPU): 10.3 s ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 7.86 s ± 31.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.28 s ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 2
Mesh size: 6374
Numpy (CPU): 10.9 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 8.83 s ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.53 s ± 9.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 3
Mesh size: 6374
Numpy (CPU): 11.6 s ± 50.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 9.75 s ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 1.78 s ± 9.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 4
Mesh size: 6374
Numpy (CPU): 12.2 s ± 32.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 10.7 s ± 38.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.02 s ± 14.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 5
Mesh size: 6374
Numpy (CPU): 12.8 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 11.8 s ± 274 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.35 s ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 6
Mesh size: 6374
Numpy (CPU): 13.7 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 13 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.57 s ± 19.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 7
Mesh size: 6374
Numpy (CPU): 14.3 s ± 21.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 14 s ± 39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 2.8 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 8
Mesh size: 6374
Numpy (CPU): 15 s ± 30.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 15 s ± 40.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.02 s ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 9
Mesh size: 6374
Numpy (CPU): 15.5 s ± 48.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 15.6 s ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.26 s ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Number of iterations: 10
Mesh size: 6374
Numpy (CPU): 16.1 s ± 27.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (CPU): 16.6 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX (GPU): 3.5 s ± 9.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[23]:
def plot_iteration_sweep(num_iterations, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
    xs = num_iterations
    ys_ref = np.array([result["Numpy (CPU)"][0].average for result in results.values()])
    labels = list(results[num_iterations[0]])
    for label in labels:
        ys = np.array([result[label][0].average for result in results.values()])
        yerr = np.array([result[label][0].stdev for result in results.values()])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes
[24]:
fig, axes = plot_iteration_sweep(num_iterations, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{device.points.shape[0]} mesh vertices"
)
