# Changelog

All notable changes to TorchGW are documented in this file.
## [0.4.0] — 2026-04-07

Major performance and robustness release: 3-6x faster on typical workloads, with Triton GPU kernel acceleration, mixed-precision support, and comprehensive numerical-stability fixes.
### Performance

- **Triton fused Sinkhorn kernels** — Custom GPU kernels for the Sinkhorn row/column logsumexp updates, reducing kernel launches from ~6 to 1 per half-step. 2-5x speedup on the Sinkhorn portion (5001×6001 fp32: 261 ms → 50 ms). Includes fused transport-plan materialization and a fused marginal-error check. Falls back to PyTorch automatically when Triton is unavailable. (`_triton_sinkhorn.py`)
- **Sinkhorn warm-start** — Reuse the log-domain potentials (`log_u`, `log_v`) from the previous GW iteration as initial values. Reduces Sinkhorn convergence from ~10 to ~3-5 steps.
- **GPU sampling** — Replace CPU NumPy sampling with `torch.multinomial` on GPU. Transfers 2×M integers instead of the full N×K transport plan per iteration.
- **Mixed precision** — New `mixed_precision=True` parameter runs Sinkhorn in float32 (safe in the log domain) while keeping marginals and output in float64. Up to 1.7x faster on A100/L40S; larger gains expected on consumer GPUs.
- **Dijkstra caching** — `DijkstraProvider` caches per-node SSSP results across iterations with FIFO eviction (max 2000 rows per side), avoiding redundant computation when the same anchor nodes are re-sampled.
- **Cost-plateau early stopping** — GW-cost EMA plus patience-based convergence detection. Stops when the smoothed cost stops improving, rather than waiting for the noisy `||T - T_prev||` to drop below `tol` (which may never happen due to sampling noise). Example: dijkstra 1000×1200 stops at 97 iterations instead of running all 500.
- **Parallel all-pairs Dijkstra** — `PrecomputedProvider` runs source- and target-graph Dijkstra in parallel via process-based parallelism (SciPy holds the GIL). 1.2-1.5x speedup on large graphs (≥2000 total nodes).
- **Reduced CUDA sync points** — Convergence checks batched every 5 iterations; augmented penalty computed on GPU without an `.item()` sync.
- **Sinkhorn convergence check via logsumexp** — Avoids materializing the full N×K matrix for the marginal-error computation.
- **Pre-allocated augmented cost matrix** — Reused across iterations instead of being re-allocated each step.
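The plateau criterion combines an exponential moving average of the noisy per-iteration cost with a patience counter. A minimal stand-alone sketch of the idea follows; the function and parameter names (`ema_beta`, `patience`, `min_rel_improve`) are illustrative assumptions, not TorchGW's actual API:

```python
def run_with_plateau_stopping(costs, ema_beta=0.9, patience=10,
                              min_rel_improve=1e-4):
    """Return the iteration index at which the smoothed cost plateaus.

    `costs` is the sequence of noisy per-iteration GW costs; the EMA damps
    sampling noise so a plateau is detectable even when the raw cost never
    drops below a fixed tolerance.
    """
    ema = None
    best = float("inf")
    stale = 0
    for it, c in enumerate(costs):
        # Exponentially smoothed cost.
        ema = c if ema is None else ema_beta * ema + (1.0 - ema_beta) * c
        if ema < best * (1.0 - min_rel_improve):
            best = ema       # meaningful improvement: reset patience
            stale = 0
        else:
            stale += 1       # no meaningful improvement this iteration
            if stale >= patience:
                return it    # smoothed cost has stopped improving
    return len(costs) - 1
```

On a cost trace that decreases and then flattens out, this stops well before the iteration budget; on a strictly improving trace it runs to the end.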
New Features
mixed_precisionparameter forsampled_gwandsampled_lowrank_gwlambda_ema_betaparameter for cost matrix EMA smoothing (variance reduction)Verbose Sinkhorn output (
verbose=Trueprints per-iteration marginal errors)sample_pairs_gpu()— GPU-native weighted sampling functionsample_pairs_from_plan()now accepts optionalrngparameter for reproducibility
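The point of plan-based sampling is that only 2×M integer indices leave the device, never the full N×K plan. Here is an illustrative NumPy reimplementation of the idea (using `rng.choice` where the library uses `torch.multinomial`); the body is a sketch, not TorchGW's actual code:

```python
import numpy as np

def sample_pairs_from_plan(T, n_samples, rng=None):
    """Draw (rows, cols) index pairs with probability proportional to T.

    NumPy sketch of the GPU sampling path: indices are drawn from the
    flattened plan, so only 2 x n_samples integers need to be returned.
    `rng` mirrors the new optional reproducibility parameter.
    """
    rng = np.random.default_rng(rng)
    p = T.ravel().astype(np.float64)
    p /= p.sum()                                  # normalize to a distribution
    flat = rng.choice(p.size, size=n_samples, p=p)
    return np.unravel_index(flat, T.shape)        # (rows, cols) arrays
```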
### Bug Fixes

#### Numerical stability

- `torch.log(a + 1e-300)` replaced with `.clamp(min=1e-30)` — the 1e-300 constant underflows to zero in float32, providing no protection against log(0)
- Regularization decay capped at 10x to prevent instability with large epsilon values
- Low-rank mirror descent: enforce `gamma * reg >= 1` to prevent exponential overflow
- Handle all-inf distance matrices (fully disconnected subgraphs) without crashing
- `sample_pairs_gpu` casts to float32 before `torch.multinomial` (required on some PyTorch versions/devices)
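The underflow behind the first fix is easy to reproduce in NumPy: 1e-300 is far below the smallest float32 subnormal (~1e-45), so it rounds to exactly zero and the log still hits -inf, while a 1e-30 clamp stays representable:

```python
import numpy as np

a = np.zeros(3, dtype=np.float32)         # probabilities that hit exactly 0

# Old guard: the additive constant underflows to 0.0 in float32,
# so log(a + 1e-300) is still log(0) = -inf.
old = np.log(a + np.float32(1e-300))

# New guard: 1e-30 is representable in float32, so the log is finite.
new = np.log(np.clip(a, 1e-30, None).astype(np.float32))

print(np.isinf(old).all(), np.isfinite(new).all())   # True True
```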
#### Correctness

- kNN graph symmetrized via `.maximum(.T)` — `kneighbors_graph` returns directed edges
- Semi-relaxed Sinkhorn: correct KL proximal blend `tau * new + (1-tau) * old` instead of `tau * new`, which discarded history
- `differentiable=True` + `semi_relaxed=True` now raises `NotImplementedError` (the envelope-theorem gradient is invalid for unbalanced Sinkhorn)
- Detach `T_prev` in differentiable mode to prevent computation-graph accumulation across GW iterations (OOM after many iterations)
- `joint_embedding`: prevent index out-of-bounds when `out_dim > k_svds`
- `lambda_ema_beta=0.0` now disables EMA (previously it locked onto the first iteration's cost)
- Dijkstra cache eviction safety: never evict keys needed by the current request
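The proximal-blend bug is worth spelling out: `tau * new` not only discards the previous iterate, it also throws away mass. A toy NumPy illustration (the marginal values are made up for the example):

```python
import numpy as np

tau = 0.3
old = np.array([0.5, 0.3, 0.2])   # previous marginal (sums to 1)
new = np.array([0.4, 0.4, 0.2])   # freshly computed marginal (sums to 1)

buggy = tau * new                  # history discarded; total mass shrinks to tau
fixed = tau * new + (1 - tau) * old  # convex blend; mass and history preserved

print(buggy.sum(), fixed.sum())    # 0.3 vs 1.0
```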
#### Compatibility

- `scipy.sparse.linalg.cg`: auto-detect the `tol` vs `rtol` parameter name for SciPy 1.10-1.17+ compatibility
- `sampled_lowrank_gw`: `semi_relaxed` validation moved to the start of the function (fail-fast)
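One common way to implement that auto-detection (a sketch of the pattern; TorchGW's actual code may differ) is to inspect the installed `cg`'s signature once and pick the keyword accordingly:

```python
import inspect

def tolerance_kwarg(func, value):
    """Return {'rtol': value} if `func` accepts rtol, else {'tol': value}.

    SciPy renamed cg's `tol` parameter to `rtol`; inspecting the installed
    signature keeps a single call site working across versions.
    """
    params = inspect.signature(func).parameters
    return {"rtol" if "rtol" in params else "tol": value}

# Demo against stand-in signatures (no SciPy required):
def cg_old(A, b, tol=1e-5): ...
def cg_new(A, b, rtol=1e-5): ...

print(tolerance_kwarg(cg_old, 1e-8))   # {'tol': 1e-08}
print(tolerance_kwarg(cg_new, 1e-8))   # {'rtol': 1e-08}
```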
#### API consistency

- `sampled_lowrank_gw` now accepts the `mixed_precision` parameter
- Removed unused `semi_relaxed`/`rho`/`**kwargs` from the `sinkhorn_lowrank` signature
- `sample_pairs_from_plan` returns `(rows, cols)` arrays instead of `list[tuple]`
### Tests

- 72 tests covering all solver modes, mixed precision, early stopping, the Dijkstra cache, differentiable gradients, boundary values, and semi-relaxed mode
- Test suite runs in ~18 s (down from ~68 s before optimizations)
### Documentation

- `docs/optimization-log.md` — detailed optimization history with benchmarks
- `docs/improvements.md` — updated future directions (torch.compile, cuGraph, Triton extensions)
## [0.3.0] — 2026-04-03

Initial public release.

- Sampled Gromov-Wasserstein solver (`sampled_gw`) with log-domain Sinkhorn
- Low-rank solver (`sampled_lowrank_gw`) via mirror descent + Dykstra
- Three distance modes: `dijkstra`, `precomputed`, `landmark`
- Fused Gromov-Wasserstein (`fgw_alpha` blending)
- Multiscale warm-start via farthest-point sampling
- Differentiable transport plans (`differentiable=True`)
- Semi-relaxed mode for unbalanced transport
- Joint manifold embedding (`joint_embedding`)
- kNN graph construction with component stitching (`build_knn_graph`)
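The log-domain Sinkhorn at the core of `sampled_gw` can be sketched in a few lines of NumPy. This is an illustrative reimplementation under standard entropic-OT conventions, not TorchGW's code: the dual potentials are updated through logsumexp, so `exp()` never sees a large argument.

```python
import numpy as np

def _logsumexp(X, axis):
    """Numerically stable log(sum(exp(X))) along `axis`."""
    m = X.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def log_sinkhorn(C, a, b, eps=0.1, n_iter=300):
    """Entropic OT plan between marginals a (N,) and b (K,) for cost C (N, K).

    Alternating log-domain updates of the scaled dual potentials f, g;
    the plan is materialized only once at the end.
    """
    log_a, log_b = np.log(a), np.log(b)
    M = -C / eps
    f = np.zeros_like(a)   # scaled dual potentials
    g = np.zeros_like(b)
    for _ in range(n_iter):
        f = log_a - _logsumexp(M + g[None, :], axis=1)  # row half-step
        g = log_b - _logsumexp(M + f[:, None], axis=0)  # column half-step
    return np.exp(M + f[:, None] + g[None, :])
```

After convergence the row and column sums of the returned plan match `a` and `b`; the warm-start optimization in 0.4.0 amounts to carrying `f` and `g` over to the next GW iteration instead of re-zeroing them.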