TorchGW – Fast Sampled Gromov-Wasserstein Optimal Transport
Note
Source code: github.com/chansigit/torchgw — clone, star, or open issues on GitHub.
TorchGW is a scalable solver for Gromov-Wasserstein optimal transport, implemented in pure PyTorch with GPU-accelerated Triton fused Sinkhorn kernels.
It aligns two point clouds by matching their internal distance structures — even when the point clouds live in different dimensions — making it ideal for manifold alignment, single-cell multi-omics integration, and cross-domain graph matching.
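To make "matching internal distance structures" concrete, the sketch below evaluates the standard squared-loss Gromov-Wasserstein objective for a given coupling in plain PyTorch. It does not call TorchGW: the helper name `gw_cost` and the toy data are illustrative assumptions, and the formula is the usual textbook decomposition rather than TorchGW's internal implementation.

```python
import torch

def gw_cost(C1, C2, T):
    """Squared-loss GW objective: sum_{i,j,k,l} (C1[i,k] - C2[j,l])^2 * T[i,j] * T[k,l].

    Uses the expansion (a - b)^2 = a^2 + b^2 - 2ab so that no 4-index tensor is built.
    """
    p = T.sum(dim=1)                         # source marginal, shape (N,)
    q = T.sum(dim=0)                         # target marginal, shape (K,)
    term_src = p @ (C1 ** 2) @ p             # sum_{i,k} C1[i,k]^2 p[i] p[k]
    term_tgt = q @ (C2 ** 2) @ q             # sum_{j,l} C2[j,l]^2 q[j] q[l]
    cross = ((C1 @ T @ C2.T) * T).sum()      # sum_{i,j,k,l} C1[i,k] C2[j,l] T[i,j] T[k,l]
    return term_src + term_tgt - 2.0 * cross

torch.manual_seed(0)
X = torch.randn(6, 3, dtype=torch.float64)   # source cloud in 3-D

# Target cloud: the same points, zero-padded into 5-D and shuffled.
perm = torch.randperm(6)
Y = torch.cat([X, torch.zeros(6, 2, dtype=torch.float64)], dim=1)[perm]

C1 = torch.cdist(X, X)                       # intra-source distances
C2 = torch.cdist(Y, Y)                       # intra-target distances

# Coupling that sends source point perm[j] to target point j (its shuffled copy).
T_match = torch.zeros(6, 6, dtype=torch.float64)
T_match[perm, torch.arange(6)] = 1.0 / 6

# Uninformative coupling that spreads mass evenly.
T_unif = torch.full((6, 6), 1.0 / 36, dtype=torch.float64)

cost_match = gw_cost(C1, C2, T_match).item()
cost_uniform = gw_cost(C1, C2, T_unif).item()
```

Because padding and shuffling preserve all pairwise distances, the matching coupling has a cost of zero up to rounding error, while the uniform coupling has a strictly larger cost, even though the two clouds live in different dimensions.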
Key features:
Up to 175x faster than POT on typical workloads (spiral 4000×5000: 1s vs 183s)
Triton fused Sinkhorn — single-pass online logsumexp, no intermediate N×K matrices
Mixed precision — float32 Sinkhorn + float64 output, zero quality loss
Smart early stopping — cost plateau detection, not just transport plan norm
Differentiable — use GW cost as a training loss with autograd support
No POT dependency at runtime — pure PyTorch + scipy + scikit-learn
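The "single-pass online logsumexp" idea can be illustrated without Triton. The sketch below streams over column tiles of a log-domain Sinkhorn update and keeps only a running maximum and a rescaled running sum per row, so the full matrix of exponents is never materialized. This is a plain-PyTorch illustration under assumed names (`streaming_logsumexp`, `block`), not TorchGW's kernel; a real fused kernel would keep each tile in on-chip memory.

```python
import torch

def streaming_logsumexp(g, C, eps, block=64):
    """Row-wise logsumexp of (g[j] - C[i, j]) / eps, accumulated tile by tile over columns."""
    n, k = C.shape
    running_max = torch.full((n,), float("-inf"), dtype=C.dtype)
    running_sum = torch.zeros(n, dtype=C.dtype)
    for start in range(0, k, block):
        # Only this (n, block) tile of exponents exists at any one time.
        z = (g[start:start + block] - C[:, start:start + block]) / eps
        new_max = torch.maximum(running_max, z.max(dim=1).values)
        # Rescale the sum accumulated so far to the new maximum, then add this tile.
        running_sum = (running_sum * torch.exp(running_max - new_max)
                       + torch.exp(z - new_max[:, None]).sum(dim=1))
        running_max = new_max
    return running_max + torch.log(running_sum)

torch.manual_seed(0)
C = torch.rand(50, 300, dtype=torch.float64)     # toy cost matrix
g = torch.randn(300, dtype=torch.float64)        # toy dual potential
eps = 0.05                                       # entropic regularization strength

out = streaming_logsumexp(g, C, eps)
ref = torch.logsumexp((g[None, :] - C) / eps, dim=1)   # materializes the full matrix
```

The streamed result agrees with the reference up to floating-point rounding, which is why the fused formulation can save memory without changing the Sinkhorn update itself.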
What’s New in v0.4.0
Triton fused Sinkhorn kernels (2–5x GPU speedup)
Mixed precision support (mixed_precision=True)
Cost plateau early stopping
Sinkhorn warm-start across GW iterations
15 numerical stability and correctness fixes
See Changelog for details
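As a rough illustration of cost plateau early stopping, the sketch below stops once the relative change in cost over a sliding window falls below a tolerance. The rule, the class name `PlateauDetector`, and the `window` and `rel_tol` parameters are all assumptions made for this example; TorchGW's actual criterion is described in the Changelog and API Reference, not here.

```python
from collections import deque

class PlateauDetector:
    """Signals a plateau when the cost changed by less than rel_tol over `window` steps."""

    def __init__(self, window=5, rel_tol=1e-4):
        self.window = window
        self.rel_tol = rel_tol
        self.history = deque(maxlen=window + 1)   # keeps the last window + 1 costs

    def update(self, cost):
        """Record a cost value; return True once the cost has plateaued."""
        self.history.append(float(cost))
        if len(self.history) <= self.window:
            return False                          # not enough history yet
        oldest, newest = self.history[0], self.history[-1]
        rel_change = abs(oldest - newest) / max(abs(oldest), 1e-12)
        return rel_change < self.rel_tol

# Toy cost curve that decays geometrically towards 0.5.
detector = PlateauDetector(window=5, rel_tol=1e-4)
stopped_at = None
for k in range(200):
    cost = 0.5 + 0.5 ** k
    if detector.update(cost):
        stopped_at = k
        break
```

Monitoring the cost directly, rather than only the change in the transport plan, lets the outer loop exit as soon as further iterations stop improving the objective.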
Contents
- Quick Start
- API Reference
- Algorithm
- Benchmark
- Changelog
- TorchGW Optimization Log
  - Summary
  - Phase 1: GPU Sampling + Kernel Fusion Prep
  - Phase 2: Dijkstra Caching
  - Phase 3: Mixed Precision
  - Phase 4: Early Stopping
  - Phase 5: Sync Reduction
  - Phase 6: Triton Fused Sinkhorn
  - Phase 7: Sinkhorn Warm-Start
  - Phase 8: Parallel Preprocessing
  - Bug Fixes During Optimization
  - Remaining Bottlenecks
  - Future Directions
Installation
pip install -e .
Dependencies: numpy, scipy, scikit-learn, torch, joblib.
Triton ships with PyTorch 2.0+ (Linux GPU builds) and enables GPU kernel fusion automatically.
Quick Example
from torchgw import sampled_gw
# X_source and X_target are the two point clouds; their feature dimensions may differ
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
# T[i, j] = optimal coupling weight between source point i and target point j
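Two common ways to use a coupling are a hard assignment (the best-matching target for each source point) and a barycentric projection (each source point mapped into target space). The sketch below shows both on a synthetic stand-in for T instead of calling `sampled_gw`; it assumes T can be treated as a dense (N, K) torch tensor, which is an assumption about the return type rather than something stated here.

```python
import torch

torch.manual_seed(0)
N, K, d_t = 5, 7, 3

# Synthetic stand-in for a coupling: non-negative entries with total mass 1.
T = torch.rand(N, K, dtype=torch.float64)
T = T / T.sum()

X_target = torch.randn(K, d_t, dtype=torch.float64)   # toy target cloud

# Hard assignment: the target index that receives the most mass from each source point.
hard_match = T.argmax(dim=1)                          # shape (N,)

# Barycentric projection: each source point becomes the T-weighted mean of target points.
row_mass = T.sum(dim=1, keepdim=True)                 # mass leaving each source point
weights = T / row_mass                                # each row now sums to 1
X_source_mapped = weights @ X_target                  # shape (N, d_t)
```

The hard assignment is convenient for label transfer, while the barycentric projection keeps the soft information in T and places the source points directly in the target coordinate system.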
Source Code & Links
GitHub repository: chansigit/torchgw
Issue tracker: GitHub Issues
Changelog: CHANGELOG.md
PyPI: coming soon
# Clone and install from source
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e .
Citation
If you use TorchGW in your research, please cite:
@software{torchgw,
author = {Sijie Chen},
title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
url = {https://github.com/chansigit/torchgw},
version = {0.4.0},
year = {2026},
}
License
Free for academic and non-commercial use. Commercial use requires a separate license. See LICENSE and COMMERCIAL_LICENSE.md for details. Contact: chansigit@gmail.com