# Algorithm

## Gromov-Wasserstein Optimal Transport
Given two metric spaces \((\mathcal{X}, C_1)\) and \((\mathcal{Y}, C_2)\) with distributions \(p\) and \(q\), the Gromov-Wasserstein (GW) distance finds a transport plan \(T\) that minimizes:

\[\min_{T} \sum_{i, k} \sum_{j, l} \left( C_1(i, k) - C_2(j, l) \right)^2 T_{ij} T_{kl}\]

subject to \(T \mathbf{1} = p\) and \(T^\top \mathbf{1} = q\).
Unlike Wasserstein distance, GW does not require the two spaces to share a common metric — it compares intra-domain distances, making it suitable for cross-domain alignment (e.g., different modalities, different dimensionalities).
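As a concreteness check, the double-sum objective above can be evaluated naively for small problems. This is a plain-PyTorch sketch; the function name is illustrative, not part of the TorchGW API:

```python
import torch

def gw_objective(C1, C2, T):
    """Naive O(N^2 K^2) evaluation of the GW objective (small problems only).

    E(T) = sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l]
    """
    # broadcast to a (N, K, N, K) tensor of squared distance discrepancies
    diff = C1[:, None, :, None] - C2[None, :, None, :]
    # contract against the coupling on both index pairs (i,j) and (k,l)
    return torch.einsum("ijkl,ij,kl->", diff ** 2, T, T)
```

With \(\mathcal{X} = \mathcal{Y}\) and the identity coupling, the objective is zero, since every intra-domain distance is matched exactly.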
## Sampled GW (TorchGW)
Standard entropic GW computes a cost matrix of size \(N \times K\) at each iteration using all \(N \times N\) and \(K \times K\) pairwise distances. TorchGW reduces this by sampling \(M\) anchor pairs per iteration.
Each iteration:

1. Sample \(M\) anchor pairs \((i, j)\) from the current transport plan \(T\), weighted by coupling mass.
2. Run Dijkstra shortest paths from the \(\leq M\) unique sampled source nodes on both kNN graphs.
3. Assemble the cost matrix on GPU:
   \[\Lambda = \text{mean}(D_{\text{left}}^2) - \frac{2}{M} D_{\text{left}} D_{\text{tgt}}^\top + \text{mean}(D_{\text{tgt}}^2)\]
4. Run augmented Sinkhorn with slack variables for partial transport. The cost matrix is augmented to \((N+1) \times (K+1)\) with penalty rows/columns, allowing the solver to assign mass to “slack” when alignment is poor.
5. Apply a momentum update:
   \[T \leftarrow (1 - \alpha) T_{\text{prev}} + \alpha T_{\text{new}}\]
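The cost-assembly step is a batched expansion of the mean squared difference between distance profiles. A minimal sketch, assuming the shapes below (the function name and argument layout are illustrative, not the library's actual signature):

```python
import torch

def assemble_cost(D_left, D_tgt):
    """Sketch of the sampled GW cost assembly.

    D_left: (N, M) shortest-path distances from each source node to the
            M sampled anchors on the source kNN graph.
    D_tgt:  (K, M) distances from each target node to the matched anchors
            on the target kNN graph.

    Returns Lambda with Lambda[n, k] = mean_m (D_left[n, m] - D_tgt[k, m])**2,
    expanded as mean(D_left^2) - (2/M) * D_left @ D_tgt.T + mean(D_tgt^2).
    """
    M = D_left.shape[1]
    left_sq = (D_left ** 2).mean(dim=1, keepdim=True)   # (N, 1), broadcasts over columns
    tgt_sq = (D_tgt ** 2).mean(dim=1, keepdim=True).T   # (1, K), broadcasts over rows
    return left_sq - (2.0 / M) * (D_left @ D_tgt.T) + tgt_sq
```

The expansion trades the \((N, K, M)\) intermediate of a direct computation for one \(N \times K\) matrix product, which is what keeps the assembly GPU-friendly.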
### Complexity
| Component | Standard GW | TorchGW |
|---|---|---|
| Cost matrix per iter | \(O(NK(N+K))\) | \(O(NKM)\) |
| Dijkstra per iter | \(O((N+K)(N+K) \log(N+K))\) | \(O(M(N+K) \log(N+K))\) |
| Sinkhorn per iter | \(O(NK)\) | \(O(NK)\) (same) |
With \(M \ll \min(N, K)\), TorchGW achieves sub-quadratic scaling in the number of Dijkstra computations while maintaining the same Sinkhorn cost.
## Log-Domain Sinkhorn
TorchGW uses a pure-PyTorch log-domain Sinkhorn implementation for numerical stability with small regularization \(\varepsilon\). The dual potentials \(f\) and \(g\) are updated with log-sum-exp operations:

\[f \leftarrow \varepsilon \log p - \varepsilon \,\text{LSE}_j\!\left(\log K + \frac{g}{\varepsilon}\right), \qquad g \leftarrow \tau \left(\varepsilon \log q - \varepsilon \,\text{LSE}_i\!\left(\log K + \frac{f}{\varepsilon}\right)\right)\]

where \(\log K = -C / \varepsilon\) and \(\tau = 1\) for balanced GW.
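A minimal log-domain Sinkhorn in plain PyTorch, following the update order described above (a sketch under those assumptions, not TorchGW's actual implementation):

```python
import torch

def log_sinkhorn(C, log_p, log_q, eps, n_iter=200, tau=1.0):
    """Log-domain Sinkhorn: dual potentials f, g updated via logsumexp.

    tau = 1 enforces both marginals (balanced); tau < 1 softens the
    target marginal (semi-relaxed).
    """
    log_K = -C / eps
    f = torch.zeros_like(log_p)
    g = torch.zeros_like(log_q)
    for _ in range(n_iter):
        f = eps * log_p - eps * torch.logsumexp(log_K + g[None, :] / eps, dim=1)
        g = tau * (eps * log_q - eps * torch.logsumexp(log_K + f[:, None] / eps, dim=0))
    # recover the transport plan from the optimal potentials
    return torch.exp(log_K + f[:, None] / eps + g[None, :] / eps)
```

Because every operation stays in log space, a small \(\varepsilon\) never underflows the kernel \(K = e^{-C/\varepsilon}\), which is exactly what breaks the naive multiplicative Sinkhorn iteration.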
## Semi-Relaxed GW
Setting \(\tau = \rho / (\rho + \varepsilon)\) with \(\tau < 1\) relaxes the target marginal constraint via a KL divergence penalty:

- \(\rho \to \infty\): \(\tau \to 1\), recovers balanced GW
- \(\rho \to 0\): \(\tau \to 0\), target marginal is completely free
This is useful when source and target have different compositions.
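The interpolation between the two regimes is just a scalar map; a trivial helper (the name is ours, for illustration) makes the limits easy to verify:

```python
def tau_from_rho(rho, eps):
    # marginal-relaxation coefficient used by semi-relaxed Sinkhorn:
    # tau -> 1 as rho -> infinity (balanced), tau -> 0 as rho -> 0 (free marginal)
    return rho / (rho + eps)
```

In practice \(\rho\) is the knob exposed to the user, and \(\tau\) is derived from it together with the current \(\varepsilon\).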
## Differentiable Sinkhorn
When `differentiable=True`, TorchGW uses a custom `torch.autograd.Function` that:

1. runs the Sinkhorn loop in the forward pass without recording the computation graph,
2. saves only the resulting transport plan \(T\), and
3. computes gradients via the envelope theorem:

\[\frac{\partial \mathcal{L}}{\partial C} = -\frac{T}{\varepsilon} \odot \frac{\partial \mathcal{L}}{\partial T}\]

where \(\odot\) is the elementwise product.
This avoids backpropagating through all Sinkhorn iterations, making memory cost \(O(NK)\) regardless of the number of Sinkhorn steps.
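A sketch of this pattern with a hypothetical `SinkhornFn` (TorchGW's actual class, solver settings, and iteration count may differ):

```python
import torch

class SinkhornFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, C, log_p, log_q, eps):
        # run the solver without recording the computation graph
        with torch.no_grad():
            log_K = -C / eps
            f = torch.zeros_like(log_p)
            g = torch.zeros_like(log_q)
            for _ in range(100):
                f = eps * log_p - eps * torch.logsumexp(log_K + g[None, :] / eps, dim=1)
                g = eps * log_q - eps * torch.logsumexp(log_K + f[:, None] / eps, dim=0)
            T = torch.exp(log_K + f[:, None] / eps + g[None, :] / eps)
        # save only the plan; no per-iteration activations are kept
        ctx.save_for_backward(T)
        ctx.eps = eps
        return T

    @staticmethod
    def backward(ctx, grad_T):
        (T,) = ctx.saved_tensors
        # envelope theorem: dL/dC = -(T / eps) elementwise-times dL/dT
        return -(T / ctx.eps) * grad_T, None, None, None
```

The backward pass touches only the saved \(N \times K\) plan, which is where the \(O(NK)\) memory bound comes from.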
## Regularization Schedule
The entropic regularization \(\varepsilon\) is decayed exponentially during optimization:

\[\varepsilon_t = \max\left(\varepsilon_{\min},\ \varepsilon_0 \cdot \gamma^{t}\right)\]

where \(\gamma \in (0, 1)\) is the decay factor and \(\varepsilon_{\min} = \min(5 \times 10^{-4}, \varepsilon_0)\). Large initial regularization helps exploration; small final regularization sharpens the transport plan.
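The schedule amounts to a one-liner; this sketch assumes a per-iteration decay factor `gamma` (the exact parameterization TorchGW exposes is an assumption here):

```python
def eps_schedule(eps0, gamma, step):
    """Exponential epsilon decay with a hard floor.

    The floor min(5e-4, eps0) ensures the schedule never increases epsilon
    when the user already starts below the default minimum.
    """
    eps_min = min(5e-4, eps0)
    return max(eps_min, eps0 * gamma ** step)
```

Early iterations then run with a smooth, easy-to-optimize objective, and later iterations refine toward a near-deterministic plan.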