GPU acceleration

This notebook demonstrates GPU acceleration of superscreen models using the JAX package from Google.

Below we will look at the time required to solve both single-layer and multi-layer models using three different methods:

  1. Numpy (CPU): This is the default behavior when you call superscreen.solve() or superscreen.solve(..., gpu=False).

  2. JAX (CPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed but does not have a GPU available. In this notebook we benchmark this method by using the context manager jax.default_device(jax.devices("cpu")[0]) to force JAX to run on the CPU even if a GPU is available.

  3. JAX (GPU): This method will be used if you call superscreen.solve(..., gpu=True) on a machine that has JAX installed and has a GPU available. The notebook will simply skip this method if no GPU is available. The three call patterns are sketched just below.
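
For concreteness, here is a minimal sketch of the three call patterns, using a small single-layer ring similar to the device defined later in this notebook; the solver arguments are otherwise identical across the three methods.

import jax
import superscreen as sc
from superscreen.geometry import circle

# A small single-layer ring, similar to the device defined later in this notebook.
device = sc.Device(
    "ring",
    layers=[sc.Layer("base", london_lambda=0.100, thickness=0.025, z0=0)],
    films=[sc.Polygon("ring", layer="base", points=circle(3))],
    holes=[sc.Polygon("hole", layer="base", points=circle(1))],
    length_units="um",
)
device.make_mesh(min_points=1000, smooth=50)
solve_kwargs = dict(circulating_currents={"hole": "1 mA"})

# 1. Numpy (CPU): the default behavior.
solutions = sc.solve(device, **solve_kwargs)  # same as gpu=False

# 2. JAX (CPU): force JAX onto the CPU even if a GPU is present.
with jax.default_device(jax.devices("cpu")[0]):
    solutions = sc.solve(device, gpu=True, **solve_kwargs)

# 3. JAX (GPU): used automatically when gpu=True and JAX detects a GPU.
solutions = sc.solve(device, gpu=True, **solve_kwargs)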

Notes

There are several things to keep in mind in order to get the most out of superscreen on a GPU:

  • Currently, only computation performed inside superscreen.solve() can be offloaded to a GPU. Post-processing methods belonging to superscreen.Solution cannot (currently) be run on a GPU. Calculations that make heavy use of these post-processing methods (such as fluxoid optimization) can be expected to show only modest speed-ups from GPU acceleration.

  • There is some overhead associated with moving data onto the GPU for processing, so calls to superscreen.solve() that take only tens of milliseconds on the CPU (e.g. single-layer models with only 1,000 or so mesh vertices) may actually be slightly slower on the GPU.

  • On the other hand, most data only needs to be transferred to the GPU once per call to superscreen.solve(), so large gains can be seen when iteratively solving models involving multi-layer devices.

In short, you can expect the biggest speed-ups for models with many mesh vertices and/or more than one layer.
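
A quick way to check up front which backend gpu=True will actually use is to inspect jax.devices(); this is the same check used by the benchmark function below. A minimal sketch:

import jax

# The first entry of jax.devices() is the default device. On a CPU-only
# machine its device_kind contains "cpu"; with a GPU available it reports
# the accelerator name instead.
default_device = jax.devices()[0]
print(default_device.platform, default_device.device_kind)

if "cpu" in default_device.device_kind:
    print("No GPU available: superscreen.solve(..., gpu=True) will run JAX on the CPU.")
else:
    print("GPU available: superscreen.solve(..., gpu=True) will run on the GPU.")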

[1]:
%config InlineBackend.figure_formats = {"retina", "png"}
%matplotlib inline

import logging
from collections import defaultdict

import jax
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (6, 4.5)
plt.rcParams["font.size"] = 12

import superscreen as sc
from superscreen.geometry import circle
[2]:
sc.version_table()
[2]:
Software          Version
SuperScreen       0.8.0
Numpy             1.23.1
SciPy             1.8.1
matplotlib        3.5.2
ray               1.13.0
jax               0.3.15
IPython           8.4.0
Python            3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0]
OS                posix [linux]
Number of CPUs    Physical: 2, Logical: 2
BLAS Info         OPENBLAS

Wed Dec 14 15:57:21 2022 PST
[3]:
def run_benchmark(device, min_points_values, solve_kwargs):
    # Suppress warnings about not having a GPU.
    logging.getLogger("superscreen.solve").setLevel(logging.CRITICAL)
    mesh_sizes = []
    results = defaultdict(list)
    for min_points in min_points_values:
        device.make_mesh(min_points=min_points, smooth=50)
        mesh_sizes.append(device.points.shape[0])
        print(f"Mesh size: {device.points.shape[0]}")

        key = "Numpy (CPU)"
        print(f"    {key}: ", end="")
        timing = %timeit -o sc.solve(device, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (CPU)"
        print(f"    {key}: ", end="")
        with jax.default_device(jax.devices("cpu")[0]):
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
        results[key].append(timing)

        key = "JAX (GPU)"
        print(f"    {key}: ", end="")
        if "cpu" in jax.devices()[0].device_kind:
            print("Skipping because there is no GPU available.")
        else:
            timing = %timeit -o sc.solve(device, gpu=True, **solve_kwargs)
            results[key].append(timing)
    logging.getLogger("superscreen.solve").setLevel(logging.WARNING)
    return np.array(mesh_sizes), dict(results)
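
Note that the -o flag makes %timeit return an IPython TimeitResult object instead of only printing the timing; the plotting helpers below read its average and stdev attributes (both in seconds). A minimal standalone sketch with a generic NumPy workload (not superscreen.solve()):

import numpy as np

# -o returns a TimeitResult rather than just printing the timing.
A = np.random.rand(500, 500)
b = np.random.rand(500)
timing = %timeit -o np.linalg.solve(A, b)
print(f"mean = {timing.average:.2e} s, std = {timing.stdev:.2e} s, best = {timing.best:.2e} s")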
[4]:
def plot_benchmark(mesh_sizes, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
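    # Panels: solve time (linear y-scale), solve time (log y-scale),
    # and speed-up relative to the Numpy (CPU) baseline.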
    ys_ref = np.array([t.average for t in results["Numpy (CPU)"]])
    for label, timing in results.items():
        xs = mesh_sizes
        ys = np.array([t.average for t in timing])
        yerr = np.array([t.stdev for t in timing])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Mesh size")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes

Single layer device

Here we model a single-layer superconducting ring with some circulating current.

[5]:
length_units = "um"
ro = 3  # outer radius
ri = 1  # inner radius
slit_width = 0.25
layer = sc.Layer("base", london_lambda=0.100, thickness=0.025, z0=0)

ring = circle(ro)
hole = circle(ri)
bounding_box = sc.Polygon("bounding_box", layer="base", points=circle(1.2 * ro))

device = sc.Device(
    "ring",
    layers=[layer],
    films=[sc.Polygon("ring", layer="base", points=ring)],
    holes=[sc.Polygon("hole", layer="base", points=hole)],
    abstract_regions=[bounding_box],
    length_units=length_units,
)
device.solve_dtype = "float32"

fig, ax = device.draw(exclude="bounding_box", legend=True)
../_images/notebooks_gpu_6_0.png

We get the same answer (to within floating-point precision) for the mutual inductance using either Numpy on the CPU or JAX.

[6]:
device.make_mesh(min_points=5000, smooth=50)
[7]:
M_numpy = device.mutual_inductance_matrix(units="pH", gpu=False)
display(M_numpy)
Magnitude: [[5.769204365742559]]
Units: picohenry
[8]:
M_jax = device.mutual_inductance_matrix(units="pH", gpu=True)
display(M_jax)
Magnitude: [[5.76919768357049]]
Units: picohenry
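
As a quick numerical check (a minimal sketch, assuming M_numpy and M_jax are the pint quantities computed in the two cells above), the two results agree to a relative difference of order 1e-6, consistent with the float32 solve_dtype set above:

import numpy as np

# Compare the two mutual inductance matrices after stripping the pint units.
rel_diff = np.abs(M_numpy.magnitude - M_jax.magnitude) / np.abs(M_numpy.magnitude)
print(f"max relative difference: {rel_diff.max():.2e}")
assert np.allclose(M_numpy.magnitude, M_jax.magnitude, rtol=1e-4)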

Measure the wall time of superscreen.solve() vs. the number of points in the mesh.

[9]:
# Model the device with 1 mA circulating current
circulating_currents = {"hole": "1 mA"}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
)

# Look at solve time vs. mesh size
min_points = 1000 * np.arange(1, 11, dtype=int)
[10]:
mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 1015
    Numpy (CPU): 27.6 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    JAX (CPU): 53 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 41 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 2105
    Numpy (CPU): 126 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    JAX (CPU): 198 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 122 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3105
    Numpy (CPU): 266 ms ± 485 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 411 ms ± 7.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 201 ms ± 2.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4285
    Numpy (CPU): 549 ms ± 7.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 816 ms ± 13.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 331 ms ± 2.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5248
    Numpy (CPU): 899 ms ± 8.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.33 s ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 509 ms ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6390
    Numpy (CPU): 1.46 s ± 75.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 2.2 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 668 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7051
    Numpy (CPU): 1.91 s ± 85.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 2.77 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 854 ms ± 57.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8748
    Numpy (CPU): 3.04 s ± 89.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 4.34 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.27 s ± 62.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9711
    Numpy (CPU): 3.97 s ± 98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 5.37 s ± 171 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.46 s ± 70.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 10744
    Numpy (CPU): 4.92 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 6.22 s ± 85.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.59 s ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[11]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle("Single layer device")
../_images/notebooks_gpu_14_0.png

Multi-layer device

Here we model a scanning Superconducting QUantum Interference Device (SQUID) susceptometer, which consists of three superconducting layers. In particular, we are interested in the mutual inductance between the field coil (green loop below) and pickup loop (orange loop below), which lie in different superconducting layers (called BE and W1, respectively).

This is a convenient benchmark because the mutual inductance for these devices has been thoroughly characterized experimentally and was found to be in the range \(166 \pm 4\,\Phi_0/\mathrm{A}\) (see Table 1 of arXiv:1605.09483 or DOI:10.1063/1.4961982).
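
Concretely, the mutual inductance extracted below is the total fluxoid induced in the pickup loop divided by the current circulating in the field coil,

\[M = \frac{\Phi_{\mathrm{PL}}}{I_{\mathrm{FC}}},\]

expressed in units of \(\Phi_0/\mathrm{A}\).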

[12]:
import squids
[13]:
device = squids.ibm.medium.make_squid(align_layers="bottom")
device.solve_dtype = "float32"
[14]:
_ = device.draw(exclude="bounding_box", legend=True)
../_images/notebooks_gpu_18_0.png
[15]:
# Model the device with 1 mA circulating current
I_circ = device.ureg("1 mA")
circulating_currents = {"fc_center": str(I_circ)}
solve_kwargs = dict(
    circulating_currents=circulating_currents,
    field_units="mT",
    current_units="mA",
    log_level=None,
    iterations=5,
)

We get the same answer (to within floating-point precision) for the mutual inductance using either Numpy on the CPU or JAX.

[16]:
device.make_mesh(min_points=6000, smooth=50)
[17]:
solution_numpy = sc.solve(device, **solve_kwargs)[-1]
mutual_numpy = (sum(solution_numpy.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_numpy = {mutual_numpy:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_numpy = 162.92121 Φ_0/A
[18]:
solution_jax = sc.solve(device, gpu=True, **solve_kwargs)[-1]
mutual_jax = (sum(solution_jax.hole_fluxoid("pl_center")) / I_circ).to("Phi_0 / A")
print(f"mutual_jax = {mutual_jax:.5f~P}")
WARNING:superscreen.solve:Layer 'W2': The film thickness, d = 0.2000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0320 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'W1': The film thickness, d = 0.1000 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0640 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
WARNING:superscreen.solve:Layer 'BE': The film thickness, d = 0.1600 µm, is greater than or equal to the London penetration depth, resulting in an effective penetration depth Λ = 0.0400 µm <= λ = 0.0800 µm. The assumption that the current density is nearly constant over the thickness of the film may not be valid.
mutual_jax = 162.92123 Φ_0/A
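
As a rough cross-check (a minimal sketch, assuming mutual_numpy and mutual_jax are the pint quantities computed above), we can compare the simulated values to the measured range of \(166 \pm 4\,\Phi_0/\mathrm{A}\):

# Compare the simulated mutual inductance to the measured 166 +/- 4 Phi_0/A.
measured, uncertainty = 166.0, 4.0  # Phi_0 / A
for label, mutual in [("numpy", mutual_numpy), ("jax", mutual_jax)]:
    value = mutual.to("Phi_0 / A").magnitude
    deviation = abs(value - measured)
    print(
        f"{label}: {value:.2f} Phi_0/A, {deviation / measured:.1%} from the measured mean, "
        f"{'inside' if deviation <= uncertainty else 'outside'} the quoted uncertainty"
    )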
[19]:
fig, axes = solution_jax.plot_fields(figsize=(10, 2.75))
fig.tight_layout()
../_images/notebooks_gpu_24_0.png

Measure the wall time of superscreen.solve() vs. the number of points in the mesh.

[20]:
min_points = 1000 * np.arange(2, 11, dtype=int)

mesh_sizes, results = run_benchmark(device, min_points, solve_kwargs)
Mesh size: 2280
    Numpy (CPU): 1.72 s ± 4.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.98 s ± 7.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.03 s ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 3032
    Numpy (CPU): 3.15 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 3.48 s ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.21 s ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 4290
    Numpy (CPU): 6.02 s ± 27.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 6.03 s ± 30.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.6 s ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 5399
    Numpy (CPU): 9.39 s ± 33.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 8.77 s ± 28.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.94 s ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 6374
    Numpy (CPU): 12.9 s ± 37.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 11.7 s ± 46.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.29 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 7568
    Numpy (CPU): 18 s ± 36.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 16 s ± 34.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.72 s ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 8379
    Numpy (CPU): 21.9 s ± 33.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 19.5 s ± 38.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.12 s ± 17.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 9169
    Numpy (CPU): 26.4 s ± 79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 23.4 s ± 68.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.44 s ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mesh size: 10112
    Numpy (CPU): 32 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 28 s ± 63.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.85 s ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[21]:
fig, axes = plot_benchmark(mesh_sizes, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{solve_kwargs['iterations']} iterations"
)
../_images/notebooks_gpu_27_0.png

Here we measure the solve time vs. the number of iterations for a multi-layer device with a fixed number of mesh vertices.

[22]:
num_iterations = np.arange(11, dtype=int)
min_points = 6000

results = {}
for iterations in num_iterations:
    print(f"Number of iterations: {iterations}")
    solve_kwargs["iterations"] = iterations
    mesh_size, timing = run_benchmark(device, [min_points], solve_kwargs)
    results[iterations] = timing
    print()
Number of iterations: 0
Mesh size: 6374
    Numpy (CPU): 1.11 s ± 8.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 1.79 s ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 798 ms ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 1
Mesh size: 6374
    Numpy (CPU): 10.3 s ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 7.86 s ± 31.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.28 s ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 2
Mesh size: 6374
    Numpy (CPU): 10.9 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 8.83 s ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.53 s ± 9.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 3
Mesh size: 6374
    Numpy (CPU): 11.6 s ± 50.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 9.75 s ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 1.78 s ± 9.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 4
Mesh size: 6374
    Numpy (CPU): 12.2 s ± 32.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 10.7 s ± 38.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.02 s ± 14.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 5
Mesh size: 6374
    Numpy (CPU): 12.8 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 11.8 s ± 274 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.35 s ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 6
Mesh size: 6374
    Numpy (CPU): 13.7 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 13 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.57 s ± 19.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 7
Mesh size: 6374
    Numpy (CPU): 14.3 s ± 21.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 14 s ± 39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 2.8 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 8
Mesh size: 6374
    Numpy (CPU): 15 s ± 30.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 15 s ± 40.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.02 s ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 9
Mesh size: 6374
    Numpy (CPU): 15.5 s ± 48.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 15.6 s ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.26 s ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Number of iterations: 10
Mesh size: 6374
    Numpy (CPU): 16.1 s ± 27.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (CPU): 16.6 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    JAX (GPU): 3.5 s ± 9.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

[23]:
def plot_iteration_sweep(num_iterations, results):
    fig, axes = plt.subplots(
        1, 3, figsize=(10, 3), sharex=True, constrained_layout=True
    )
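    # Panels: solve time (linear y-scale), solve time (log y-scale),
    # and speed-up relative to the Numpy (CPU) baseline.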
    xs = num_iterations
    ys_ref = np.array([result["Numpy (CPU)"][0].average for result in results.values()])
    labels = list(results[num_iterations[0]])
    for label in labels:
        ys = np.array([result[label][0].average for result in results.values()])
        yerr = np.array([result[label][0].stdev for result in results.values()])
        for ax in axes[:2]:
            ax.errorbar(xs, ys, yerr=yerr, marker="o", ls="--", label=label)
        axes[2].plot(xs, ys_ref / ys, marker="o", ls="--")
    for ax in axes:
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Solve time [s]")
        ax.grid(True)
    axes[2].set_ylabel("Speed-up vs. Numpy (CPU)")
    axes[0].legend(loc=0)
    axes[1].set_yscale("log")
    return fig, axes
[24]:
fig, axes = plot_iteration_sweep(num_iterations, results)
_ = fig.suptitle(
    f"Multi-layer device: {len(device.layers)} layers, "
    f"{device.points.shape[0]} mesh vertices"
)
../_images/notebooks_gpu_31_0.png