GPU acceleration

This notebook demonstrates GPU acceleration of superscreen models using the JAX package from Google.

Below we will look at the time required to solve both single-layer and multi-layer models using three different methods (a minimal sketch of the corresponding calls follows the list):

  1. Numpy (CPU): This is the default behavior when you call superscreen.solve() or superscreen.solve(..., gpu=False).

  2. JAX (CPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed but does not have a GPU available. We test this method by using the context manager jax.default_device(jax.devices("cpu")[0]) to force JAX to run on the CPU even if a GPU is available.

  3. JAX (GPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed and has a GPU available. The notebook will simply skip this method if no GPU is available.
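
For reference, here is a minimal sketch of the three call patterns, assuming an existing superscreen.Device named device:

import jax

import superscreen as sc

# 1. Numpy (CPU): the default solver backend.
solutions = sc.solve(device)

# 2. JAX (CPU): request the JAX backend, but pin JAX to the CPU.
with jax.default_device(jax.devices("cpu")[0]):
    solutions = sc.solve(device, gpu=True)

# 3. JAX (GPU): request the JAX backend; the GPU is used if JAX can see one.
solutions = sc.solve(device, gpu=True)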

Notes

There are several things to keep in mind in order to get the most out of superscreen on a GPU:

  • Currently, only computation performed inside superscreen.solve() can be offloaded to a GPU. Post-processing methods belonging to superscreen.Solution cannot yet be run on a GPU, so calculations that make heavy use of these methods (such as fluxoid optimization) can be expected to show only modest speed-ups from GPU acceleration.

  • There is some overhead associated with moving data onto the GPU for processing, so calls to superscreen.solve() that take only tens of milliseconds on the CPU (e.g. single-layer models with only 1,000 or so mesh vertices) may actually be slightly slower on the GPU.

  • On the other hand, most data only needs to be transferred to the GPU once per call to superscreen.solve(), so large gains can be seen when iteratively solving models involving multi-layer devices.

In short, you can expect the biggest speed-ups for models with many mesh vertices and/or more than one layer.
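
To check which backend JAX will select before running the benchmarks, you can inspect the devices JAX can see (a quick standalone check; the benchmark function below uses the device_kind of the first device for the same purpose):

import jax

# "cpu", "gpu", or "tpu", depending on the backend JAX selected.
print(jax.default_backend())

# The benchmark below skips the GPU timings if this contains "cpu".
print(jax.devices()[0].device_kind)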

[1]:
%config InlineBackend.figure_formats = {"retina", "png"}
%matplotlib inline

import logging
from collections import defaultdict

import jax
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (6, 4.5)
plt.rcParams["font.size"] = 12

import superscreen as sc
from superscreen.geometry import circle
[2]:
sc.version_table()
[2]:
Software         Version
SuperScreen      0.6.1
Numpy            1.23.1
SciPy            1.8.1
matplotlib       3.5.2
ray              1.13.0
jax              0.3.15
IPython          8.4.0
Python           3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0]
OS               posix [linux]
Number of CPUs   Physical: 2, Logical: 2
BLAS Info        OPENBLAS

Wed Aug 03 08:52:54 2022 PDT
[3]:
def run_benchmark(device, min_points_values, solve_kwargs):
    # Suppress warnings about not having a GPU.
    logging.getLogger("superscreen.solve").setLevel(logging.CRITICAL)
    mesh_sizes = []
    results = defaultdict(list)
    for min_points in min_points_values:
        device.make_mesh(min_points=min_points, optimesh_steps=10)
        mesh_sizes.append(device.points.shape[0])
        print(f"Mesh size: {device.points.shape[0]}")

        key = "Numpy (CPU)"
        print(f"    {key}: ", end="")
        timing = %timeit -o sc.solve(device, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (CPU)"
        print(f"    {key}: ", end="")
        with jax.default_device(jax.devices("cpu")[0]):
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (GPU)"
        print(f"    {key}: ", end="")
        if "cpu" in jax.devices()[0].device_kind:
            print("Skipping because there is no GPU available.")
        else:
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
            results[key].append(timing)
    logging.getLogger("superscreen.solve").setLevel(logging.WARNING)
    return np.array(mesh_sizes), dict(results)
[4]:
def plot_benchmark(mesh_sizes, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
    ys_ref = np.array([t.average for t in results["Numpy (CPU)"]])
    for label, timing in results.items():
        xs = mesh_sizes
        ys = np.array([t.average for t in timing])
        yerr = np.array([t.stdev for t in timing])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Mesh size")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes

Single layer device

Here we model a single-layer superconducting ring with some circulating current.

[5]:
length_units = "um"
ro = 3  # outer radius
ri = 1  # inner radius
slit_width = 0.25
layer = sc.Layer("base", london_lambda=0.100, thickness=0.025, z0=0)

ring = circle(ro)
hole = circle(ri)
bounding_box = sc.Polygon("bounding_box", layer="base", points=circle(1.2 * ro))

device = sc.Device(
    "ring",
    layers=[layer],
    films=[sc.Polygon("ring", layer="base", points=ring)],
    holes=[sc.Polygon("hole", layer="base", points=hole)],
    abstract_regions=[bounding_box],
    length_units=length_units,
)
device.solve_dtype = "float32"

fig, ax = device.draw(exclude="bounding_box", legend=True)
../_images/notebooks_gpu_6_0.png

We get the same answer for the mutual inductance using either Numpy on the CPU or JAX.

[6]:
device.make_mesh(min_points=5000, optimesh_steps=10)
[7]:
M_numpy = device.mutual_inductance_matrix(units="pH", gpu=False)
display(M_numpy)
Magnitude: [[5.814198608377977]]
Units: picohenry
[8]:
M_jax = device.mutual_inductance_matrix(units="pH", gpu=True)
display(M_jax)
Magnitude: [[5.814189082004743]]
Units: picohenry
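
The two results agree to within single precision, which is expected since we set device.solve_dtype = "float32" above. A quick consistency check (a minimal sketch, assuming M_numpy and M_jax from the two cells above):

import numpy as np

# The Numpy and JAX backends should agree to within float32 precision.
assert np.allclose(M_numpy.magnitude, M_jax.magnitude, rtol=1e-5)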

Measure the wall time of superscreen.solve() vs. the number of points in the mesh.

[9]:
# Model the device with 1 mA circulating current
circulating_currents = {"hole": "1 mA"}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
)

# Look at solve time vs. mesh size
min_points = 1000 * np.arange(1, 11, dtype=int)
[10]:
mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 1008
    Numpy (CPU): 14.3 ms ± 31.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    JAX (CPU): 38.9 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 38.5 ms ± 1.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 2056
    Numpy (CPU): 118 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    JAX (CPU): 184 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 98.3 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3091
    Numpy (CPU): 249 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 389 ms ± 3.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 182 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4150
    Numpy (CPU): 462 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 685 ms ± 5.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 274 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5090
    Numpy (CPU): 735 ms ± 669 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.06 s ± 8.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 388 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6223
    Numpy (CPU): 1.19 s ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.68 s ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 537 ms ± 6.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7667
    Numpy (CPU): 1.89 s ± 1.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 2.57 s ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 738 ms ± 7.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8511
    Numpy (CPU): 2.46 s ± 6.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 3.29 s ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 898 ms ± 5.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9506
    Numpy (CPU): 3.21 s ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 4.21 s ± 25.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.12 s ± 5.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 10501
    Numpy (CPU): 4.09 s ± 32.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 5.31 s ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.31 s ± 9.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[11]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle("Single layer device")
../_images/notebooks_gpu_14_0.png

Multi-layer device

Here we model a scanning Superconducting QUantum Interference Device (SQUID) susceptometer, which consists of three superconducting layers. In particular, we are interested in the mutual inductance between the field coil (green loop below) and pickup loop (orange loop below), which lie in different superconducting layers (called BE and W1, respectively).

This is a convenient benchmark because the mutual inductance for these devices has been thoroughly characterized experimentally and was found to be in the range \(166 \pm 4\,\Phi_0/\mathrm{A}\) (see Table 1 of arXiv:1605.09483 or DOI:10.1063/1.4961982).
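
For intuition, that value can be converted to SI inductance units: since \(\Phi_0 \approx 2.068\times 10^{-15}\,\mathrm{Wb}\), \(166\,\Phi_0/\mathrm{A} \approx 0.34\,\mathrm{pH}\). A standalone sketch of the conversion using pint:

import pint

ureg = pint.UnitRegistry()
Phi_0 = ureg("magnetic_flux_quantum")  # ~2.068e-15 Wb
# 166 Phi_0 / A expressed in picohenries (~0.343 pH).
print((166 * Phi_0 / ureg("ampere")).to("pH"))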

[12]:
import squids
[13]:
device = squids.ibm.medium.make_squid(align_layers="bottom")
device.solve_dtype = "float32"
[14]:
_ = device.draw(exclude="bounding_box", legend=True)
../_images/notebooks_gpu_18_0.png
[15]:
# Model the device with 1 mA circulating current
I_circ = device.ureg("1 mA")
circulating_currents = {"fc_center": str(I_circ)}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
    iterations=5,
)

We get the same answer for the mutual inductance using either Numpy on the CPU or JAX.

[16]:
device.make_mesh(min_points=6000, optimesh_steps=20)
[17]:
solution_numpy = sc.solve(device, **solve_kwargs)[-1]
mutual_numpy = (sum(solution_numpy.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_numpy = {mutual_numpy:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_numpy = 164.89460 Φ_0/A
[18]:
solution_jax = sc.solve(device, gpu=True, **solve_kwargs)[-1]
mutual_jax = (sum(solution_jax.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_jax = {mutual_jax:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_jax = 164.89461 Φ_0/A
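
Both backends give essentially the same answer, well within the experimental range quoted above. A quick check (assuming mutual_numpy and mutual_jax from the cells above):

# Both results should lie within the measured range of 166 ± 4 Phi_0/A.
for label, M in [("Numpy", mutual_numpy), ("JAX", mutual_jax)]:
    value = M.to("Phi_0 / A").magnitude
    print(f"{label}: {value:.2f} Phi_0/A, within experimental range: {162 <= value <= 170}")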
[19]:
fig, axes = solution_jax.plot_fields(figsize=(10, 2.75))
fig.tight_layout()
../_images/notebooks_gpu_24_0.png

Measure the wall time of superscreen.solve() vs. the number of points in the mesh.

[20]:
min_points = 1000 * np.arange(2, 11, dtype=int)

mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 2298
    Numpy (CPU): 1.68 s ± 1.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.62 s ± 4.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 832 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3012
    Numpy (CPU): 2.88 s ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 2.97 s ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 992 ms ± 9.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4240
    Numpy (CPU): 5.51 s ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 4.93 s ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.28 s ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5318
    Numpy (CPU): 8.57 s ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 7.69 s ± 35.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.61 s ± 5.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6244
    Numpy (CPU): 11.8 s ± 23.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 10.2 s ± 42.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.91 s ± 6.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7448
    Numpy (CPU): 16.6 s ± 32.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 14.4 s ± 21.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.32 s ± 7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8187
    Numpy (CPU): 20.3 s ± 32.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 17.2 s ± 31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.63 s ± 8.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9016
    Numpy (CPU): 24.5 s ± 26.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 20.8 s ± 51.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.93 s ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 11124
    Numpy (CPU): 37.3 s ± 67.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 31.4 s ± 16.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.86 s ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[21]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{solve_kwargs['iterations']} iterations"
)
../_images/notebooks_gpu_27_0.png

Here we measure the solve time vs. the number of iterations for a multi-layer device with a fixed number of mesh vertices.

[22]:
num_iterations = np.arange(11, dtype=int)
min_points = 6000

results = {}
for iterations in num_iterations:
    print(f"Number of iterations: {iterations}")
    solve_kwargs["iterations"] = iterations
    mesh_size, timing = run_benchmark(device, [min_points], solve_kwargs)
    results[iterations] = timing
    print()
Number of iterations: 0
Mesh size: 6244
    Numpy (CPU): 919 ms ± 786 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.47 s ± 5.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 662 ms ± 7.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 1
Mesh size: 6244
    Numpy (CPU): 9.68 s ± 29.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 7.07 s ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.08 s ± 5.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 2
Mesh size: 6244
    Numpy (CPU): 10.2 s ± 36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 7.92 s ± 31.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.29 s ± 6.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 3
Mesh size: 6244
    Numpy (CPU): 10.8 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 8.71 s ± 24.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.52 s ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 4
Mesh size: 6244
    Numpy (CPU): 11.4 s ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 9.57 s ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.73 s ± 7.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 5
Mesh size: 6244
    Numpy (CPU): 11.9 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 10.4 s ± 36.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.95 s ± 6.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 6
Mesh size: 6244
    Numpy (CPU): 12.5 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 11.2 s ± 28.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.16 s ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 7
Mesh size: 6244
    Numpy (CPU): 13.1 s ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 12.1 s ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.37 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 8
Mesh size: 6244
    Numpy (CPU): 13.6 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 12.9 s ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.58 s ± 6.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 9
Mesh size: 6244
    Numpy (CPU): 14.2 s ± 34.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 13.7 s ± 30.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.81 s ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 10
Mesh size: 6244
    Numpy (CPU): 14.8 s ± 30 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 14.6 s ± 63.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.02 s ± 17.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

[23]:
def plot_iteration_sweep(num_iterations, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
    xs = num_iterations
    ys_ref = np.array([result["Numpy (CPU)"][0].average for result in results.values()])
    labels = list(results[num_iterations[0]])
    for label in labels:
        ys = np.array([result[label][0].average for result in results.values()])
        yerr = np.array([result[label][0].stdev for result in results.values()])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes
[24]:
fig, axes = plot_iteration_sweep(num_iterations, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{device.points.shape[0]} mesh vertices"
)
../_images/notebooks_gpu_31_0.png