Changelog
All notable changes to TorchGW are documented in this file.
[0.4.1] — 2026-04-09
Exact differentiable gradients via implicit differentiation. Fixes a correctness bug where the previous “envelope theorem” backward produced gradients with up to 30x error (cosine similarity as low as 0.07).
Breaking Changes
Default gradient computation for differentiable=True is now implicit differentiation (exact) instead of the old frozen-potentials approximation. No API change is needed; the new default is strictly better.
New Features
grad_mode parameter for sampled_gw: controls how gradients are computed when differentiable=True:
  "implicit" (default): exact gradient via the adjoint system at the Sinkhorn fixed point, solved via a Schur complement on the Sinkhorn Jacobian. Memory: O(NK + K^2). Same speed as the old approximate mode.
  "unrolled": exact gradient via unrolled PyTorch autograd. Memory: O(NK * sinkhorn_iters). Useful as a fallback at extremely small epsilon.
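To make the two memory bounds concrete, here is a back-of-the-envelope estimate. It assumes a constant factor of 1 in front of each asymptotic term and float64 storage, neither of which is stated above; `grad_memory_estimate` is a hypothetical helper, not part of TorchGW.

```python
def grad_memory_estimate(n, k, sinkhorn_iters, bytes_per_elem=8):
    """Order-of-magnitude memory per grad_mode, in bytes.

    Assumption: a constant factor of 1 in front of each asymptotic term.
    Real usage depends on implementation details not covered here.
    """
    implicit = (n * k + k * k) * bytes_per_elem       # O(NK + K^2)
    unrolled = n * k * sinkhorn_iters * bytes_per_elem  # O(NK * sinkhorn_iters)
    return {"implicit": implicit, "unrolled": unrolled}


est = grad_memory_estimate(n=5000, k=6000, sinkhorn_iters=100)
for mode, nbytes in est.items():
    print(f"{mode}: {nbytes / 1e9:.2f} GB")
```

Under these assumptions the unrolled mode grows linearly with the number of Sinkhorn iterations, which is why it is better suited as a fallback than as a default.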
Bug Fixes
Gradient correctness: The old backward formula grad_C = -grad_T * T / reg treated the Sinkhorn potentials as constants ("frozen potentials"). This is a first-order approximation that ignores how the potentials depend on C through the Sinkhorn iterations. The new implicit-differentiation backward solves the adjoint system derived from the implicit function theorem at the Sinkhorn fixed point, giving exact gradients.
Adjoint solver stability: The initial implementation used fixed-point iteration for the adjoint system, which diverges when the spectral radius is >= 1 (common at small epsilon). It was replaced with a Schur complement direct solve on the Sinkhorn Jacobian J^T (eigenvalues in [0,2], well-conditioned). The null space arising from the constant ambiguity in the potentials is removed via a rank-1 correction (11^T/K).
Warning for non-differentiable pure GW: differentiable=True with fgw_alpha=0 now emits a warning, since gradients cannot flow through precomputed graph distances.
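The gap between the frozen-potentials formula and the true gradient can be seen on a tiny balanced problem. The sketch below is a dependency-free illustration, not TorchGW code: it uses plain Sinkhorn scaling and central finite differences as the reference (the release itself uses implicit differentiation), and the cost matrix, marginals, and upstream gradient `G` are made up. Because shifting a whole row of C by a constant is absorbed by the potentials and leaves T unchanged, each row of the true gradient must sum to zero, which the frozen formula violates.

```python
import math


def sinkhorn_plan(C, a, b, reg, iters=500):
    """Balanced entropic OT plan via plain Sinkhorn scaling (tiny dense problems only)."""
    n, k = len(a), len(b)
    K = [[math.exp(-C[i][j] / reg) for j in range(k)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * k
    for _ in range(iters):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(k)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(k)]
    return [[u[i] * K[i][j] * v[j] for j in range(k)] for i in range(n)]


def loss(C, G, a, b, reg):
    """Scalar loss <G, T(C)>, so the upstream gradient dL/dT equals G."""
    T = sinkhorn_plan(C, a, b, reg)
    return sum(G[i][j] * T[i][j] for i in range(len(a)) for j in range(len(b)))


# Made-up problem data (assumptions for illustration).
a, b, reg = [0.2, 0.3, 0.5], [0.4, 0.6], 0.5
C = [[0.0, 1.0], [1.0, 0.5], [0.3, 0.8]]
G = [[1.0, 2.0], [0.5, 1.5], [2.0, 0.3]]
n, k = len(a), len(b)

# Frozen-potentials formula quoted in the changelog: grad_C = -grad_T * T / reg.
T = sinkhorn_plan(C, a, b, reg)
frozen = [[-G[i][j] * T[i][j] / reg for j in range(k)] for i in range(n)]

# Reference gradient via central finite differences through the full solve.
h = 1e-5
exact = [[0.0] * k for _ in range(n)]
for i in range(n):
    for j in range(k):
        C_plus = [row[:] for row in C]
        C_minus = [row[:] for row in C]
        C_plus[i][j] += h
        C_minus[i][j] -= h
        exact[i][j] = (loss(C_plus, G, a, b, reg) - loss(C_minus, G, a, b, reg)) / (2 * h)

dot = sum(frozen[i][j] * exact[i][j] for i in range(n) for j in range(k))
norm_frozen = math.sqrt(sum(x * x for row in frozen for x in row))
norm_exact = math.sqrt(sum(x * x for row in exact for x in row))
cosine = dot / (norm_frozen * norm_exact)

print("row sums (exact): ", [sum(row) for row in exact])
print("row sums (frozen):", [sum(row) for row in frozen])
print("cosine similarity:", cosine)
```

The numbers here are specific to this toy problem and say nothing about the error magnitudes reported for TorchGW itself.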
Internal
_SinkhornAutograd renamed to _SinkhornApproximate (frozen-potentials, used internally for semi-relaxed only)
New _SinkhornImplicit class (exact VJP via adjoint)
New _sinkhorn_unrolled function (exact VJP via autograd)
_sinkhorn_differentiable rewritten as dispatcher for the three backends
Tests
8 new gradient correctness tests in tests/test_sinkhorn_grad.py: implicit vs unrolled (rel_err < 2%), implicit vs finite differences, descent direction, non-uniform marginals, approximate formula verification, grad_mode validation
Integration tests for grad_mode through sampled_gw
111 total tests, all passing
Documentation
docs/algorithm.rst: full derivation of implicit differentiation (Jacobian structure, adjoint equation, Schur complement, null-space handling)
README: updated news, differentiable mode example with grad_mode
[0.4.0] — 2026-04-07
Major performance and robustness release. 3-6x faster on typical workloads, with Triton GPU kernel acceleration, mixed precision support, and comprehensive numerical stability fixes.
Performance
Triton fused Sinkhorn kernels: Custom GPU kernels for the Sinkhorn row/column logsumexp updates, reducing kernel launches from ~6 to 1 per half-step. 2-5x speedup on the Sinkhorn portion (5001×6001 fp32: 261ms → 50ms). Includes fused transport plan materialization and fused marginal error check. Falls back to PyTorch automatically when Triton is unavailable. (_triton_sinkhorn.py)
Sinkhorn warm-start: Reuse log-domain potentials (log_u, log_v) from the previous GW iteration as initial values. Reduces Sinkhorn convergence from ~10 to ~3-5 steps.
GPU sampling: Replace CPU numpy sampling with torch.multinomial on GPU. Transfers 2×M integers instead of the full N×K transport plan per iteration.
Mixed precision: New mixed_precision=True parameter runs Sinkhorn in float32 (safe in log domain) while keeping marginals and output in float64. Up to 1.7x faster on A100/L40S; larger gains expected on consumer GPUs.
Dijkstra caching: DijkstraProvider caches per-node SSSP results across iterations with FIFO eviction (max 2000 rows per side). Avoids redundant computation when the same anchor nodes are re-sampled.
Cost plateau early stopping: GW cost EMA + patience-based convergence detection. Stops when the smoothed cost stops improving, rather than waiting for the noisy ||T - T_prev|| to drop below tol (which may never happen due to sampling noise). Example: dijkstra 1000×1200 stops at 97 iters instead of running all 500.
Parallel all-pairs Dijkstra: PrecomputedProvider runs source and target graph Dijkstra in parallel via process-based parallelism (scipy holds the GIL). 1.2-1.5x speedup on large graphs (≥2000 total nodes).
Reduced CUDA sync points: Convergence checks batched every 5 iterations; augmented penalty computed on GPU without .item() sync.
Sinkhorn convergence check via logsumexp: Avoids materializing full N×K matrix for the marginal error computation.
Pre-allocated augmented cost matrix: Reused across iterations instead of re-allocated each step.
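The cost plateau early stopping described above can be sketched as an EMA plus a patience counter. This is a minimal illustration under assumed defaults: the class name `PlateauStopper`, the smoothing factor, the patience, and the relative tolerance are all hypothetical and do not reflect TorchGW's actual values. The cost sequence is a synthetic decaying curve with a small alternating term standing in for sampling noise.

```python
import math


class PlateauStopper:
    """Stop when an EMA-smoothed cost has not improved for `patience` steps."""

    def __init__(self, ema_beta=0.9, patience=10, rel_tol=1e-3):
        self.ema_beta = ema_beta
        self.patience = patience
        self.rel_tol = rel_tol
        self.ema = None
        self.best = None
        self.bad_steps = 0

    def update(self, cost):
        """Feed one noisy cost value; return True once a plateau is detected."""
        if self.ema is None:
            self.ema = cost
        else:
            self.ema = self.ema_beta * self.ema + (1.0 - self.ema_beta) * cost
        # Count a step as an improvement only if the smoothed cost drops
        # by more than rel_tol relative to the best value seen so far.
        if self.best is None or self.ema < self.best - self.rel_tol * abs(self.best):
            self.best = self.ema
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience


stopper = PlateauStopper()
stopped_at = None
for it in range(500):
    # Synthetic cost: exponential decay toward 1.0 plus deterministic "noise".
    cost = 1.0 + 2.0 * math.exp(-it / 10.0) + 0.01 * (-1) ** it
    if stopper.update(cost):
        stopped_at = it
        break

print("stopped at iteration", stopped_at)
```

Smoothing before comparing is the key design choice: the raw cost keeps oscillating, so a criterion on the unsmoothed sequence could keep resetting the patience counter.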
New Features
mixed_precision parameter for sampled_gw and sampled_lowrank_gw
lambda_ema_beta parameter for cost matrix EMA smoothing (variance reduction)
Verbose Sinkhorn output (verbose=True prints per-iteration marginal errors)
sample_pairs_gpu(): GPU-native weighted sampling function
sample_pairs_from_plan() now accepts optional rng parameter for reproducibility
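As an illustration of weighted pair sampling with an optional rng, the following standard-library sketch draws (row, col) index pairs with probability proportional to the plan entries. The function `sample_pairs_sketch` is hypothetical; the signature and internals of TorchGW's `sample_pairs_from_plan` are not shown in this changelog and may differ.

```python
import random


def sample_pairs_sketch(plan, m, rng=None):
    """Draw m (row, col) index pairs with probability proportional to plan[i][j].

    Hypothetical stand-in for illustration; passing a seeded rng makes
    the draw reproducible.
    """
    if rng is None:
        rng = random.Random()
    n_cols = len(plan[0])
    weights = [w for row in plan for w in row]  # flatten row-major
    flat = rng.choices(range(len(weights)), weights=weights, k=m)
    rows = [idx // n_cols for idx in flat]
    cols = [idx % n_cols for idx in flat]
    return rows, cols


plan = [
    [0.30, 0.00, 0.10],
    [0.05, 0.40, 0.15],
]
rows_a, cols_a = sample_pairs_sketch(plan, m=200, rng=random.Random(0))
rows_b, cols_b = sample_pairs_sketch(plan, m=200, rng=random.Random(0))
print(rows_a[:5], cols_a[:5])
```

Two calls with identically seeded generators return the same pairs, which is the reproducibility property the rng parameter is meant to provide.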
Bug Fixes
Numerical stability
torch.log(a + 1e-300) replaced with .clamp(min=1e-30): the 1e-300 constant vanishes in float32, providing no protection against log(0)
Regularization decay capped at 10x to prevent instability with large epsilon values
Low-rank mirror descent: enforce gamma * reg >= 1 to prevent exponential overflow
Handle all-inf distance matrices (fully disconnected subgraphs) without crashing
sample_pairs_gpu casts to float32 before torch.multinomial (required on some PyTorch versions/devices)
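The float32 underflow behind the first fix can be checked without PyTorch. This sketch rounds Python floats to float32 via the standard `struct` module; `max(a, floor)` is used as a scalar analogue of `.clamp(min=...)`, which is an illustrative substitution rather than the library's code.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


old_eps = to_float32(1e-300)   # far below the float32 range: underflows to zero
new_floor = to_float32(1e-30)  # still a positive float32 value

a = 0.0
guarded_old = a + old_eps        # stays 0.0, so the log is unprotected
guarded_new = max(a, new_floor)  # scalar analogue of a.clamp(min=1e-30)

print("1e-300 as float32:", old_eps)
print("1e-30 as float32: ", new_floor)
print("log of guarded value:", math.log(guarded_new))
```

Note that `math.log(0.0)` raises `ValueError` in plain Python, whereas a tensor log would return `-inf`; either way the additive 1e-300 guard does nothing in float32.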
Correctness
kNN graph symmetrized via .maximum(.T): kneighbors_graph returns directed edges
Semi-relaxed Sinkhorn: correct KL proximal blend tau * new + (1-tau) * old instead of tau * new, which discarded history
differentiable=True + semi_relaxed=True now raises NotImplementedError (envelope theorem gradient is invalid for unbalanced Sinkhorn)
Detach T_prev in differentiable mode to prevent computation graph accumulation across GW iterations (OOM after many iterations)
joint_embedding: prevent index out-of-bounds when out_dim > k_svds
lambda_ema_beta=0.0 now disables EMA (previously locked to first iteration’s cost)
Dijkstra cache eviction safety: never evict keys needed by the current request
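The kNN symmetrization fix amounts to taking the elementwise maximum of the adjacency matrix and its transpose. The sketch below shows the idea on a small dense list-of-lists matrix with made-up edge weights; it is a dependency-free illustration, whereas the changelog entry refers to the `.maximum(.T)` form on the graph returned by `kneighbors_graph`.

```python
def symmetrize_max(adj):
    """Elementwise max(A, A^T) for a square dense adjacency matrix."""
    n = len(adj)
    return [[max(adj[i][j], adj[j][i]) for j in range(n)] for i in range(n)]


# Directed 1-nearest-neighbor graph (made-up weights):
# node 0 points to node 1, but node 1's nearest neighbor is node 2.
directed = [
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 2.0],
    [0.0, 2.0, 0.0],
]
undirected = symmetrize_max(directed)
for row in undirected:
    print(row)
```

Without this step the edge from node 0 to node 1 exists in only one direction, so shortest-path distances computed on the graph would depend on the direction of travel.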
Compatibility
scipy.sparse.linalg.cg: auto-detect tol vs rtol parameter name for SciPy 1.10-1.17+ compatibility
sampled_lowrank_gw: semi_relaxed validation moved to function start (fail-fast)
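One way to auto-detect a renamed keyword is to inspect the callable's signature at runtime. The sketch below is an assumption about how such detection could work, not TorchGW's actual code: `tolerance_kwargs` is a hypothetical helper, and the two stub solvers merely stand in for older and newer solver signatures so the example runs without SciPy.

```python
import inspect


def tolerance_kwargs(solver, value):
    """Return {'rtol': value} if `solver` accepts rtol, otherwise {'tol': value}."""
    params = inspect.signature(solver).parameters
    name = "rtol" if "rtol" in params else "tol"
    return {name: value}


# Stand-ins for two generations of a solver API (not SciPy itself).
def old_style_cg(A, b, tol=1e-5):
    return ("old", tol)


def new_style_cg(A, b, rtol=1e-5):
    return ("new", rtol)


old_result = old_style_cg(None, None, **tolerance_kwargs(old_style_cg, 1e-8))
new_result = new_style_cg(None, None, **tolerance_kwargs(new_style_cg, 1e-8))
print(old_result, new_result)

# With SciPy installed, the same helper could be applied as:
#   from scipy.sparse.linalg import cg
#   x, info = cg(A, b, **tolerance_kwargs(cg, 1e-8))
```

Preferring `rtol` whenever it is present also covers a signature that accepts both names during a deprecation window.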
API consistency
sampled_lowrank_gw now accepts mixed_precision parameter
Removed unused semi_relaxed / rho / **kwargs from sinkhorn_lowrank signature
sample_pairs_from_plan returns (rows, cols) arrays instead of list[tuple]
Tests
72 tests covering all solver modes, mixed precision, early stopping, Dijkstra cache, differentiable gradients, boundary values, and semi-relaxed mode
Test suite runs in ~18s (down from ~68s before optimizations)
Documentation
docs/optimization-log.md: Detailed optimization history with benchmarks
docs/improvements.md: Updated future directions (torch.compile, cuGraph, Triton extensions)
[0.3.0] — 2026-04-03
Initial public release.
Sampled Gromov-Wasserstein solver (sampled_gw) with log-domain Sinkhorn
Low-rank solver (sampled_lowrank_gw) via mirror descent + Dykstra
Three distance modes: dijkstra, precomputed, landmark
Fused Gromov-Wasserstein (fgw_alpha blending)
Multiscale warm-start via farthest-point sampling
Differentiable transport plans (differentiable=True)
Semi-relaxed mode for unbalanced transport
Joint manifold embedding (joint_embedding)
kNN graph construction with component stitching (build_knn_graph)