TL;DR: I wrote a small Cython + OpenMP library that resamples 3D volumes (e.g. medical images). It’s an (almost) drop-in replacement for torch.nn.functional.interpolate and torch.nn.functional.grid_sample, but runs up to 13× faster on CPU and doesn’t require PyTorch at all.

GitHub: JoHof/volresample
pip install volresample


The Problem

If you’ve ever built a medical imaging pipeline — segmentation, registration, anything involving CT or MRI — you spend a surprising fraction of your runtime just moving voxels around. Resampling is everywhere:

  • Normalizing spacing before feeding a scan into a network
  • Augmenting training data with resizing or deformations
  • Running inference in a sliding-window fashion, then stitching patches back together at a different resolution
  • Post-processing: upsampling predicted segmentation masks back to the original patient space

For all of these, PyTorch’s F.interpolate and F.grid_sample are the go-to tools. They’re well-tested, GPU-accelerated, and have a clean API. But then comes the awkward reality in most production and research pipelines:

  1. You’re working on CPU anyway. Most preprocessing and postprocessing happens outside the GPU hot path, in DataLoader workers, in MONAI transforms, in post-hoc analysis scripts. You don’t have a CUDA context, and you don’t want one.
  2. A default PyTorch install on Linux, with its bundled CUDA libraries, easily exceeds 2 GB. If you’re shipping a lightweight container, a CLI tool, or just a utility script, pulling in all of PyTorch for F.interpolate is embarrassing.
  3. PyTorch CPU performance has some surprising gaps. Area-mode interpolation doesn’t parallelize well for single-image workloads on PyTorch CPU. int16 isn’t natively supported and needs round-trips through float32. These aren’t bugs — they’re just not the priority for a GPU-first library.

What volresample does

It exposes two functions, intentionally (almost) mirroring PyTorch’s API:

import volresample

# Like F.interpolate(..., mode='trilinear', align_corners=False)
resampled = volresample.resample(volume, (128, 128, 128), mode='linear')

# Like F.grid_sample(..., align_corners=False)
sampled = volresample.grid_sample(input, grid, mode='linear', padding_mode='zeros')

Both accept NumPy arrays. The resample() function handles 3D (D, H, W), 4D (C, D, H, W) and 5D (N, C, D, H, W) inputs. Thread count is configurable:

volresample.set_num_threads(8)

That’s essentially the whole API.


Use Cases

Preprocessing Pipelines

In medical imaging, raw scans come in all kinds of resolutions. A CT scan might be 512×512×350 with anisotropic voxel spacing, while your network expects a 128×128×128 input or an isotropic spacing of 1.5 mm. You need to resample every single scan before training, and potentially again for each augmented copy. For evaluation, you might have to resample predictions back to the original spacing.
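Here is a minimal sketch of that first step. The target shape for a given spacing follows from keeping each axis’s physical extent (voxel count × spacing) constant. The helper name `target_shape`, the (D, H, W) axis order and the example spacing values are illustrative assumptions. None of them are part of volresample.

```python
import numpy as np

def target_shape(shape, spacing, new_spacing):
    # Hypothetical helper (not part of volresample): keep the physical extent
    # (voxels x spacing, in mm) roughly constant along each axis
    shape = np.asarray(shape, dtype=np.float64)
    spacing = np.asarray(spacing, dtype=np.float64)
    new_spacing = np.broadcast_to(np.asarray(new_spacing, dtype=np.float64), shape.shape)
    extent_mm = shape * spacing
    return tuple(int(n) for n in np.maximum(np.round(extent_mm / new_spacing), 1))

# Illustrative CT scan in (D, H, W) order: 350 slices of 512x512 at 2.0 x 0.7 x 0.7 mm
print(target_shape((350, 512, 512), (2.0, 0.7, 0.7), 1.5))  # (467, 239, 239)
```

Rounding to whole voxels means the achieved spacing is only approximately 1.5 mm. You would pass the resulting tuple as the size argument to `volresample.resample`, with `mode='linear'` for intensities or `mode='nearest'` for label maps.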

Data Augmentation

Random elastic and affine augmentations are typically implemented as a grid_sample call: compute a displacement field, then sample the volume at displaced coordinates. This is the hot path in augmentation-heavy training.
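The sketch below shows how such a sampling grid is typically built. It assumes that volresample.grid_sample uses the same grid layout as F.grid_sample: shape (N, D_out, H_out, W_out, 3), with the last axis ordered (x, y, z) and values normalized to [-1, 1] under align_corners=False. I haven’t checked this against the library itself, and the helper names are hypothetical. An elastic augmentation would add a smooth random displacement field to the same identity grid instead of applying a matrix.

```python
import numpy as np

def identity_grid(d, h, w):
    # Hypothetical helper: voxel centres in normalized [-1, 1] coordinates.
    # Under align_corners=False, index i of an axis with n voxels maps to (2*i + 1)/n - 1
    zs = (2 * np.arange(d) + 1) / d - 1
    ys = (2 * np.arange(h) + 1) / h - 1
    xs = (2 * np.arange(w) + 1) / w - 1
    z, y, x = np.meshgrid(zs, ys, xs, indexing="ij")
    # F.grid_sample layout: (N, D_out, H_out, W_out, 3), last axis ordered (x, y, z)
    return np.stack([x, y, z], axis=-1)[None].astype(np.float32)

def affine_grid(d, h, w, matrix, translation=(0.0, 0.0, 0.0)):
    # Hypothetical helper: map each output voxel centre p to matrix @ p + translation
    grid = identity_grid(d, h, w)
    m = np.asarray(matrix, dtype=np.float32)
    t = np.asarray(translation, dtype=np.float32)
    return (grid @ m.T + t).astype(np.float32)

def rotation_about_z(angle_rad):
    # Rotation in the x-y (W-H) plane; the z (depth) coordinate is unchanged
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# A random small rotation combined with a random isotropic scaling
rng = np.random.default_rng(0)
grid = affine_grid(64, 64, 64, rotation_about_z(rng.uniform(-0.2, 0.2)) * rng.uniform(0.9, 1.1))
print(grid.shape)  # (1, 64, 64, 64, 3)
```

A rotation in normalized coordinates is slightly anisotropic for non-cubic volumes, because [-1, 1] spans a different number of voxels along each axis.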

grid_sample at 4 threads on a 128³ volume: 38 ms in volresample vs 169 ms in PyTorch — a 4.4× speedup, and it scales to nearly 6× with reflection padding mode.

For the person whose augmentation pipeline is the DataLoader bottleneck (it’s more people than you’d think), this kind of throughput difference directly translates into faster experiments.

Inference Without a GPU

Lots of clinical deployment happens on CPU-only hardware. Maybe it’s a hospital workstation with no GPU, a laptop, an edge device, or a cloud VM that doesn’t have GPU quota. Your model runs on PyTorch, but most of the work around the model — resampling the input to the right spacing, resampling the output back — can be done faster without PyTorch at all.


The Benchmarks

All benchmarks on an Intel i7-8565U (4 cores), mean over 10 iterations, comparing against PyTorch 2.8.0.

resample() — 512³ → 256³:

| Mode | volresample (1 thread) | PyTorch (1 thread) | Speedup | volresample (4 threads) | PyTorch (4 threads) | Speedup |
|---|---|---|---|---|---|---|
| nearest | 23.6 ms | 38.0 ms | 1.6× | 12.6 ms | 16.7 ms | 1.3× |
| linear | 99.9 ms | 182 ms | 1.8× | 34.3 ms | 54.6 ms | 1.6× |
| area | 230 ms | 611 ms | 2.7× | 64.5 ms | 613 ms | 9.5× |
| nearest (uint8) | 13.7 ms | 33.8 ms | 2.5× | 4.3 ms | 10.4 ms | 2.4× |
| nearest (int16) | 16.5 ms | 217 ms | 13.2× | 8.4 ms | 93.2 ms | 11.2× |

grid_sample() — 128³ input:

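| Mode / padding | volresample (1 thread) | PyTorch (1 thread) | Speedup | volresample (4 threads) | PyTorch (4 threads) | Speedup |
|---|---|---|---|---|---|---|
| linear / zeros | 118 ms | 181 ms | 1.5× | 38.1 ms | 169 ms | 4.4× |
| linear / reflection | 103 ms | 211 ms | 2.1× | 33.2 ms | 194 ms | 5.9× |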

Average across all benchmarks: 3.1× at 1 thread, 6.0× at 4 threads.

Two results deserve more explanation because they’re not just “Cython is fast”:

The int16 story (13.2× at 1 thread): PyTorch has no native int16 interpolation path. If you want to call F.interpolate on an int16 tensor, you have to cast it to float32 and eventually back. For a 512³ volume, the upcast alone reads 256 MB of int16 data and writes a 512 MB float32 copy. The interpolation then reads that larger copy, and the 256³ result has to be cast back. All of this is needed only because there is no native implementation. volresample compiles a specialized nearest-neighbor path for int16 directly using Cython’s fused types: no conversion, no extra allocation. The 13× speedup comes almost entirely from eliminating the type-casting overhead and the unnecessary memory traffic.
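To put rough numbers on that overhead, here is a back-of-the-envelope count of the bytes moved by the two casts alone, for the 512³ → 256³ case from the table. It ignores caches and the interpolation kernel’s own reads, which are twice as large on the float32 copy as they would be on int16 data. Treat it as an order-of-magnitude estimate. The function name is hypothetical.

```python
from math import prod

def cast_roundtrip_bytes(in_shape, out_shape, int_bytes=2, float_bytes=4):
    # Hypothetical estimate of memory traffic caused only by the dtype casts
    # Upcast: read every int16 input voxel and write a float32 copy
    upcast = prod(in_shape) * (int_bytes + float_bytes)
    # Downcast: read the float32 result and write it back as int16
    downcast = prod(out_shape) * (float_bytes + int_bytes)
    return upcast + downcast

total = cast_roundtrip_bytes((512, 512, 512), (256, 256, 256))
print(total / 2**20)  # 864.0 (MiB), before the interpolation itself touches any data
```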

The area mode story (9.5× at 4 threads): Area interpolation (box filtering) is often the right choice for downsampling because it avoids aliasing. PyTorch’s CPU implementation of area mode for 3D doesn’t appear to parallelize over spatial dimensions for single-image workloads — you get 611 ms whether you use 1 thread or 4. volresample parallelizes over the depth dimension, going from 230 ms at 1 thread down to 65 ms at 4. If you’ve been downsampling with trilinear to avoid bad PyTorch area performance, now you don’t have to.
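PyTorch’s area mode corresponds to adaptive average pooling. Along each axis, output index i averages the input range from floor(i·in/out) to ceil((i+1)·in/out). The NumPy sketch below illustrates this; it is not volresample’s Cython code, and all names are hypothetical. It box-filters one axis at a time using prefix sums. A 3D box is a product of per-axis ranges, so averaging axis by axis gives the same result as averaging the whole box at once. The prefix sums are kept in float64 to limit rounding error on large volumes.

```python
import numpy as np

def area_bins(in_size, out_size):
    # Hypothetical helper: output i averages input[start[i]:end[i]]
    i = np.arange(out_size)
    start = (i * in_size) // out_size
    end = -((-(i + 1) * in_size) // out_size)  # integer ceil((i + 1) * in / out)
    return start, end

def area_resample_axis(x, out_size, axis):
    # Box-filter a single axis using a prefix sum over that axis
    x = np.moveaxis(x, axis, 0)
    start, end = area_bins(x.shape[0], out_size)
    csum = np.cumsum(x, axis=0, dtype=np.float64)
    csum = np.concatenate([np.zeros((1,) + x.shape[1:]), csum])  # csum[k] = sum of x[:k]
    counts = (end - start).reshape((-1,) + (1,) * (x.ndim - 1))
    means = (csum[end] - csum[start]) / counts
    return np.moveaxis(means, 0, axis)

def area_resample(vol, out_shape):
    # Separable: averaging over a 3D box equals averaging along each axis in turn
    out = np.asarray(vol, dtype=np.float64)
    for axis, size in enumerate(out_shape):
        out = area_resample_axis(out, size, axis)
    return out.astype(np.float32)

vol = np.random.default_rng(0).random((8, 6, 4), dtype=np.float32)
down = area_resample(vol, (4, 3, 2))
# For integer factors, area mode reduces to plain block averaging
assert np.allclose(down, vol.reshape(4, 2, 3, 2, 2, 2).mean(axis=(1, 3, 5)), atol=1e-6)
```

Each output depth slice depends only on its own bin of input slices. That is why splitting the work over the depth axis needs no coordination between threads.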


Architecture Choices

Some choices I made which I think helped with speed.

Fused types for multi-dtype without code duplication

Nearest-neighbor interpolation works on any integer or float type; you’re just reading a value and copying it. Rather than writing three separate implementations, Cython’s fused types let you write one:

from libc.stdint cimport uint8_t, int16_t

ctypedef fused numeric_type:
    uint8_t
    int16_t
    float

cdef void _resample_nearest(numeric_type* data_ptr, numeric_type* output_ptr, ...) noexcept nogil:
    # one implementation; Cython emits a separate C specialization for each type
    pass  # body elided

At compile time, Cython generates a separate specialization for each type. At runtime, it dispatches based on the dtype of the NumPy array. The dispatch overhead is a single Python-level type check before the C call; the inner loop itself contains no type branches.

Linear and area modes only support float32 since interpolation requires fractional weights. This is the correct trade-off — operating in lower precision during interpolation would compromise accuracy.
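The dtype rules in the last two paragraphs amount to a small lookup performed once at the Python boundary. Below is a minimal sketch of such a check. The names `_SUPPORTED` and `check_dtype` are hypothetical, and raising `TypeError` instead of casting is also an assumption. None of this is taken from volresample’s behavior.

```python
import numpy as np

# Hypothetical dtype policy mirroring the text: nearest is specialized for
# uint8, int16 and float32; linear and area only for float32
_SUPPORTED = {
    "nearest": {np.dtype(np.uint8), np.dtype(np.int16), np.dtype(np.float32)},
    "linear": {np.dtype(np.float32)},
    "area": {np.dtype(np.float32)},
}

def check_dtype(array, mode):
    # One check before entering compiled code; the inner loops never branch on dtype
    dtype = np.asarray(array).dtype
    if mode not in _SUPPORTED:
        raise ValueError(f"unknown mode {mode!r}")
    if dtype not in _SUPPORTED[mode]:
        allowed = ", ".join(sorted(d.name for d in _SUPPORTED[mode]))
        raise TypeError(f"mode={mode!r} supports {allowed}, got {dtype.name}")
    return dtype
```

In the compiled library, the equivalent step picks one fused-type specialization rather than just returning the dtype.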

Pre-computed index tables

The most obvious optimization is one that many naive implementations miss (as did the otherwise helpful LLM I used during development). Source coordinates depend only on the output index and the scale ratio, so computing them inside the inner loop is wasteful. Pre-computing them into small per-axis tables eliminates redundant floating-point arithmetic for every voxel:

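from cython.parallel import prange

# Computed once, O(D + H + W) work; compute_index maps an output index to a source index
for od in range(out_d): d_indices[od] = compute_index(od, scale_d)
for oh in range(out_h): h_indices[oh] = compute_index(oh, scale_h)
# ...

# Main loop: pure table lookups (H and W are the input height and width)
for od in prange(out_d, nogil=True):
    for oh in range(out_h):
        for ow in range(out_w):
            value = data_ptr[d_indices[od] * H * W + h_indices[oh] * W + w_indices[ow]]
            # value is then written to the output at (od, oh, ow)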

For trilinear interpolation the pre-computed tables also hold the fractional weights, which are reused across the 8 neighbors without recomputation.
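For readers without Cython at hand, here is a NumPy sketch of the same table-based scheme under PyTorch’s align_corners=False conventions.

- Nearest uses the nearest-exact rule floor((o + 0.5)·in/out).
- Linear uses the source coordinate (o + 0.5)·in/out - 0.5, clamped at 0. It stores one index pair and one fractional weight per output index and axis.

It illustrates the idea; it is not the library’s code, and the function names are hypothetical.

```python
import numpy as np

def nearest_table(in_size, out_size):
    # Hypothetical helper. nearest-exact: src = floor((o + 0.5) * in / out), clamped
    idx = np.floor((np.arange(out_size) + 0.5) * (in_size / out_size)).astype(np.intp)
    return np.minimum(idx, in_size - 1)

def linear_table(in_size, out_size):
    # Hypothetical helper. align_corners=False: src = (o + 0.5) * in / out - 0.5, clamped at 0
    src = np.maximum((np.arange(out_size) + 0.5) * (in_size / out_size) - 0.5, 0.0)
    i0 = np.minimum(np.floor(src).astype(np.intp), in_size - 1)
    i1 = np.minimum(i0 + 1, in_size - 1)
    frac = (src - i0).astype(np.float32)  # weight of i1; i0 gets 1 - frac
    return i0, i1, frac

def resample_nearest(vol, out_shape):
    # O(D + H + W) table work, then a pure gather that keeps vol's dtype
    d, h, w = (nearest_table(n, m) for n, m in zip(vol.shape, out_shape))
    return vol[np.ix_(d, h, w)]

def resample_linear(vol, out_shape):
    (d0, d1, fd), (h0, h1, fh), (w0, w1, fw) = (
        linear_table(n, m) for n, m in zip(vol.shape, out_shape))
    out = np.zeros(out_shape, dtype=np.float32)
    # Sum the 8 neighbours; each weight is a product of per-axis table entries
    for di, wd in ((d0, 1 - fd), (d1, fd)):
        for hi, wh in ((h0, 1 - fh), (h1, fh)):
            for wi, ww in ((w0, 1 - fw), (w1, fw)):
                weight = wd[:, None, None] * wh[None, :, None] * ww[None, None, :]
                out += weight * vol[np.ix_(di, hi, wi)]
    return out
```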

No branching in inner loops

For grid_sample, different padding modes (zeros, border, reflection) require fundamentally different out-of-bounds handling. Checking a flag inside the innermost loop is cheap in absolute terms, but it may break the compiler’s ability to reason about the code and prevent certain vectorization patterns.

The solution is to specialize. There are 6 separate functions, one for each (mode, padding_mode) combination, and the Python entry point dispatches to the right one based on the arguments. Each inner loop is then clean, predictable, branch-free code that the compiler can actually optimize. I wouldn’t write code like that without LLM support, though; maintaining that much duplicated code by hand would be a nightmare. Strange new world.

_grid_sample_bilinear_zeros()
_grid_sample_bilinear_border()
_grid_sample_bilinear_reflection()
_grid_sample_nearest_zeros()
...
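To make the padding semantics and the dispatch-once idea concrete, here is a 1D, nearest-only NumPy sketch. It follows PyTorch’s align_corners=False rules as I understand them. The dictionary of specialized closures stands in for the six compiled functions. All names (`_KERNELS`, `grid_sample_1d` and the helpers) are hypothetical.

```python
import numpy as np

def _coords_zeros(ix, size):
    # Hypothetical helper. No remapping: out-of-range samples are zeroed after rounding
    return ix

def _coords_border(ix, size):
    # Clamp to the first/last voxel
    return np.clip(ix, 0, size - 1)

def _coords_reflection(ix, size):
    # align_corners=False: reflect about -0.5 and size - 0.5, then clamp
    x = np.abs(ix + 0.5)
    flips = np.floor(x / size)
    extra = np.fmod(x, size)
    x = np.where(flips % 2 == 0, extra - 0.5, size - extra - 0.5)
    return np.clip(x, 0, size - 1)

def _make_nearest(coords_fn):
    # Build one specialized sampler per padding mode; it never checks a padding flag
    def sample(signal, ix):
        idx = np.rint(coords_fn(ix, signal.size)).astype(np.intp)  # round half to even
        valid = (idx >= 0) & (idx < signal.size)
        out = np.zeros(ix.shape, dtype=signal.dtype)
        out[valid] = signal[idx[valid]]
        return out
    return sample

_KERNELS = {
    ("nearest", "zeros"): _make_nearest(_coords_zeros),
    ("nearest", "border"): _make_nearest(_coords_border),
    ("nearest", "reflection"): _make_nearest(_coords_reflection),
}

def grid_sample_1d(signal, grid, mode="nearest", padding_mode="zeros"):
    # Unnormalize [-1, 1] grid values with the align_corners=False rule
    ix = ((np.asarray(grid, dtype=np.float64) + 1) * signal.size - 1) / 2
    kernel = _KERNELS[(mode, padding_mode)]  # dispatch once, outside any loop
    return kernel(signal, ix)
```

For a 4-voxel signal, the grid values -1.75, -1.25, -0.25, 1.25 and 1.75 map to source positions -2, -1, 1, 4 and 5. With zeros padding, only position 1 returns a non-zero sample. With border padding, positions are clamped to the edges. With reflection padding, they mirror back to indices 1, 0, 1, 3 and 2.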

OpenMP on the outermost spatial dimension

Parallelization happens on the depth dimension (prange over out_d). This is the standard choice because:

  • It maximizes the chunk size per thread, minimizing synchronization overhead
  • Each thread gets a contiguous slab of memory to write, improving cache efficiency
  • There are no data dependencies between output voxels

Thread count defaults to min(cpu_count, 4) on the theory that most machines top out their beneficial scaling around 4 threads for memory-bound workloads, and you don’t want a library quietly saturating all cores. It’s fully configurable though.
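The sketch below shows the same decomposition in Python: contiguous slabs of the output depth axis, one per worker, with the min(cpu_count, 4) default. It uses a thread pool as a stand-in for OpenMP. In CPython, threads only overlap while NumPy releases the GIL, so this shows the work split, not volresample’s speed. The function names are hypothetical.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def default_num_threads():
    # Hypothetical helper implementing the policy described above: min(cpu_count, 4)
    return min(os.cpu_count() or 1, 4)

def nearest_table(in_size, out_size):
    # nearest-exact source index per output index, clamped to the last voxel
    idx = np.floor((np.arange(out_size) + 0.5) * (in_size / out_size)).astype(np.intp)
    return np.minimum(idx, in_size - 1)

def resample_nearest_parallel(vol, out_shape, num_threads=None):
    num_threads = num_threads or default_num_threads()
    d_idx, h_idx, w_idx = (nearest_table(n, m) for n, m in zip(vol.shape, out_shape))
    out = np.empty(out_shape, dtype=vol.dtype)
    # Split the output depth axis into contiguous slabs, one per task
    bounds = np.linspace(0, out_shape[0], num_threads + 1).astype(int)

    def fill_slab(start, stop):
        # Each task writes a disjoint slab of `out`, so no locking is needed
        out[start:stop] = vol[np.ix_(d_idx[start:stop], h_idx, w_idx)]

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        list(pool.map(fill_slab, bounds[:-1], bounds[1:]))  # list() re-raises worker errors
    return out
```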

Build system: architecture-aware compiler flags

# setup.py
import platform

if platform.machine() in ('x86_64', 'AMD64', 'i686', 'i386'):
    extra_compile_args = ['-O3', '-mavx2', '-mfma', '-fopenmp']
else:
    extra_compile_args = ['-O3', '-fopenmp']  # ARM, etc.

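AVX2 and FMA flags are applied on x86 to let the compiler auto-vectorize. The Cython directives (boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False) remove four kinds of checks from the inner loop:

  • bounds checks
  • negative-index wraparound
  • Python division semantics
  • memoryview initialization checks

As a result, the inner loop compiles to plain C with almost no Python-level overhead.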

Testing against PyTorch

Every mode, dtype, and padding combination is tested against PyTorch’s output with atol=1e-5. This isn’t just a correctness check — it’s also documentation. The test suite is the canonical statement of behavioral compatibility:

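import numpy as np, torch, torch.nn.functional as F, volresample
x = np.random.rand(1, 1, 64, 64, 64).astype(np.float32)  # illustrative (N, C, D, H, W) input
size = (32, 48, 80)
torch_output = F.interpolate(torch.from_numpy(x), size=size, mode='trilinear', align_corners=False)
cython_output = volresample.resample(x, size, mode='linear')
assert np.allclose(torch_output.numpy(), cython_output, atol=1e-5)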

What it doesn’t do

To be honest about scope:

  • No GPU. If your data is already on a CUDA device, use PyTorch. The library is for CPU workloads.
  • No 2D images. Not directly exposed in the API, but achievable via volumes with a singleton dimension (1, H, W). There is no multi-threading benefit in that case, though: threads split the depth axis, and a depth of 1 leaves nothing to split.
  • No autograd. It’s a NumPy library. There’s no gradient computation.
  • float32 only for trilinear/area. By design — interpolation in lower precision gives questionable results, especially for medical images where intensity values carry clinical meaning.
  • No complete duplication of the PyTorch API. I didn’t want to carry over its technical debt, in particular the legacy nearest-neighbor coordinate convention (PyTorch’s mode='nearest' versus 'nearest-exact'). 'linear' is linear regardless of dimensionality, and there is no align_corners parameter: the behavior always matches PyTorch’s align_corners=False.

Thus, the PyTorch correspondence is:

| volresample | PyTorch F.interpolate |
|---|---|
| mode='nearest' | mode='nearest-exact' |
| mode='linear' | mode='trilinear' |
| mode='area' | mode='area' |

| volresample | PyTorch F.grid_sample |
|---|---|
| mode='nearest' | mode='nearest' |
| mode='linear' | mode='bilinear' (which PyTorch applies trilinearly to 5D input) |
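The legacy convention differs only in where the output samples land. Below is a small sketch of both index rules as I understand PyTorch’s CPU kernels, using scale = in/out as when an explicit output size is given. The helper names are hypothetical.

```python
import numpy as np

def legacy_nearest_indices(in_size, out_size):
    # Hypothetical helper. PyTorch mode='nearest': src = floor(o * in / out)
    src = np.floor(np.arange(out_size) * (in_size / out_size)).astype(np.intp)
    return np.minimum(src, in_size - 1)

def nearest_exact_indices(in_size, out_size):
    # Hypothetical helper. mode='nearest-exact' (volresample's 'nearest'):
    # src = floor((o + 0.5) * in / out), i.e. sample at output voxel centres
    src = np.floor((np.arange(out_size) + 0.5) * (in_size / out_size)).astype(np.intp)
    return np.minimum(src, in_size - 1)

# Downsampling 4 -> 2: the legacy rule is shifted towards index 0
print(legacy_nearest_indices(4, 2), nearest_exact_indices(4, 2))  # [0 2] [1 3]
```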

Prior art

SciPy’s ndimage.zoom also does this and has been doing it for decades. The main differences: volresample is faster for the modes it supports (SciPy’s zoom with order=1, i.e. trilinear, is slower), has explicit batched multi-channel support, and matches PyTorch’s coordinate conventions rather than SciPy’s. The last point matters when mixing the two in a training pipeline.
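The coordinate difference is easy to see by writing out where each output voxel samples the input. Below is a sketch comparing the default of scipy.ndimage.zoom (grid_mode=False), as I understand it, with PyTorch’s align_corners=False rule. The helper names are hypothetical. SciPy’s grid_mode=True option switches to an edge-aligned convention closer to PyTorch’s.

```python
import numpy as np

def scipy_zoom_coords(in_size, out_size):
    # Hypothetical helper. scipy.ndimage.zoom default (grid_mode=False):
    # the first and last voxel centres of input and output coincide
    return np.arange(out_size) * (in_size - 1) / (out_size - 1)

def torch_coords(in_size, out_size):
    # Hypothetical helper. F.interpolate, align_corners=False: volume edges coincide,
    # so centres shift by half a voxel (shown before PyTorch clamps negatives to 0)
    return (np.arange(out_size) + 0.5) * (in_size / out_size) - 0.5

print(scipy_zoom_coords(4, 8))
print(torch_coords(4, 8))
```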

SimpleITK’s ResampleImageFilter is the standard for actual clinical resampling and handles all the metadata correctly (spacing, direction cosines, etc.). volresample makes no attempt to handle image metadata: it’s a tensor operation, not a medical image processing toolkit.
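Because only an array comes back, keeping the geometry consistent is up to the caller. Under the align_corners=False convention, the physical extent of the volume is preserved. The new spacing is therefore the old spacing times in/out, and the first voxel centre moves by half the change in spacing. Below is a minimal sketch under two assumptions: axis-aligned images (identity direction cosines), and origin and spacing given in the same axis order as the array. SimpleITK reports (x, y, z), while a NumPy array from GetArrayFromImage is ordered (z, y, x). The helper name is hypothetical.

```python
import numpy as np

def resampled_geometry(origin, spacing, in_shape, out_shape):
    # Hypothetical helper: assumes identity direction cosines and the
    # align_corners=False convention, which preserves the physical extent
    origin = np.asarray(origin, dtype=np.float64)
    spacing = np.asarray(spacing, dtype=np.float64)
    scale = np.asarray(in_shape, dtype=np.float64) / np.asarray(out_shape, dtype=np.float64)
    new_spacing = spacing * scale
    # The first output voxel centre sits half an output voxel inside the original edge
    new_origin = origin + 0.5 * (new_spacing - spacing)
    return new_origin, new_spacing

origin, spacing = resampled_geometry((0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (4, 4, 4), (2, 2, 2))
print(origin, spacing)  # [0.5 0.5 0.5] [2. 2. 2.]
```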


Closing

I’m not going to pretend this is a revolutionary piece of work. It’s a fairly small Cython library that does one thing well. But “one thing well” is exactly what some people might need.

If you’re building a medical imaging system in Python and your profiler keeps pointing at resampling, give it a try. If you find something wrong with the numerical results or the performance claims, I want to know.