Algorithm
Gromov-Wasserstein Optimal Transport
Given two metric spaces \((\mathcal{X}, C_1)\) and \((\mathcal{Y}, C_2)\) with distributions \(p\) and \(q\), the Gromov-Wasserstein (GW) distance seeks a transport plan \(T\) that minimizes
\[\sum_{i,k} \sum_{j,l} \bigl(C_1(i,k) - C_2(j,l)\bigr)^2 \, T_{ij} \, T_{kl}\]
subject to \(T \mathbf{1} = p\), \(T^\top \mathbf{1} = q\), and \(T \geq 0\).
Unlike the Wasserstein distance, GW does not require the two spaces to share a common metric. It compares intra-domain distances, which makes it suitable for cross-domain alignment (e.g., across modalities or across spaces of different dimensionality).
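To make the objective concrete, here is a small NumPy sketch (our illustration, not part of TorchGW; the function names are hypothetical) that evaluates the squared-loss GW objective \(\sum_{i,k}\sum_{j,l}(C_1(i,k) - C_2(j,l))^2\,T_{ij}T_{kl}\) for a fixed plan. It computes the value once by the literal four-index sum and once through the expansion \(p^\top C_1^{\circ 2} p + q^\top C_2^{\circ 2} q - 2\langle C_1, T C_2 T^\top\rangle\), where \(p, q\) are the marginals of \(T\). The example also shows why GW suits spaces of different dimensionality: a 2-D point cloud and a shifted copy embedded in 3-D have identical intra-domain distances, so the matching plan has zero cost.

```python
import numpy as np


def gw_objective_bruteforce(C1, C2, T):
    """Literal four-index sum of (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l]."""
    N, K = T.shape
    total = 0.0
    for i in range(N):
        for k in range(N):
            for j in range(K):
                for l in range(K):
                    total += (C1[i, k] - C2[j, l]) ** 2 * T[i, j] * T[k, l]
    return total


def gw_objective(C1, C2, T):
    """Same value via expansion of the square: O(N^2 K + N K^2) instead of O(N^2 K^2)."""
    p = T.sum(axis=1)  # row marginal of T
    q = T.sum(axis=0)  # column marginal of T
    return p @ (C1 ** 2) @ p + q @ (C2 ** 2) @ q - 2.0 * np.sum(C1 * (T @ C2 @ T.T))


def pairwise_dist(X):
    """Euclidean distance matrix between the rows of X."""
    return np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))                  # 5 points in 2-D
Y = np.hstack([X, np.zeros((5, 1))]) + 1.0   # same points, embedded in 3-D and shifted
C1, C2 = pairwise_dist(X), pairwise_dist(Y)

T_match = np.eye(5) / 5                      # pairs each point with its own copy
T_indep = np.full((5, 5), 1 / 25)            # independent (product) coupling
print(gw_objective(C1, C2, T_match), gw_objective(C1, C2, T_indep))
```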
Sampled GW (TorchGW)
Standard entropic GW computes a cost matrix of size \(N \times K\) at each iteration using all \(N \times N\) and \(K \times K\) pairwise distances. TorchGW reduces this by sampling \(M\) anchor pairs per iteration.
Each iteration:
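1. Sample \(M\) anchor pairs \((i, j)\) from the current transport plan \(T\), with probability proportional to the coupling mass \(T_{ij}\).
2. Run Dijkstra on both kNN graphs from the unique sampled anchor nodes (at most \(M\) per graph: the \(i\)'s on the source graph, the \(j\)'s on the target graph), giving geodesic distances \(D_{\text{left}} \in \mathbb{R}^{N \times M}\) and \(D_{\text{tgt}} \in \mathbb{R}^{K \times M}\) from every node to the \(M\) anchors.
3. Assemble the cost matrix on the GPU:
\[\Lambda = \text{mean}(D_{\text{left}}^2) - \frac{2}{M} D_{\text{left}} D_{\text{tgt}}^\top + \text{mean}(D_{\text{tgt}}^2)\]
where the means are taken over the \(M\) anchors and broadcast across rows and columns, so that \(\Lambda_{ij} = \frac{1}{M}\sum_{m=1}^{M} \bigl(D_{\text{left}}[i,m] - D_{\text{tgt}}[j,m]\bigr)^2\).
4. Solve an augmented Sinkhorn problem with slack variables for partial transport. The cost matrix is augmented to \((N+1) \times (K+1)\) with penalty rows and columns, which lets the solver assign mass to “slack” when alignment is poor.
5. Apply a momentum update:
\[T \leftarrow (1 - \alpha) T_{\text{prev}} + \alpha T_{\text{new}}\]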
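A minimal NumPy sketch of the sampling, shortest-path, and cost-assembly steps above, assuming a symmetrized Euclidean kNN graph and a heap-based Dijkstra. `knn_graph`, `dijkstra`, and `sampled_cost` are hypothetical helpers, not TorchGW functions, and the slack-augmented Sinkhorn and momentum steps are left out. Dijkstra runs once per unique anchor node, and the three-term formula for \(\Lambda\) is evaluated with matrix products; entrywise it equals the mean over anchors of \((D_{\text{left}}[i,m] - D_{\text{tgt}}[j,m])^2\).

```python
import heapq
import numpy as np


def knn_graph(X, k):
    """Symmetrized kNN graph as an adjacency dict {node: {neighbor: edge_length}}."""
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    adj = {i: {} for i in range(len(X))}
    for i in range(len(X)):
        for j in map(int, np.argsort(D[i])[1:k + 1]):  # position 0 is the point itself
            adj[i][j] = adj[j][i] = float(D[i, j])
    return adj


def dijkstra(adj, source):
    """Single-source shortest-path lengths on the graph (inf if unreachable)."""
    dist = np.full(len(adj), np.inf)
    dist[source] = 0.0
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale heap entry
        for v, w in adj[u].items():
            if d + w < dist[v]:
                dist[v] = d + w
                heapq.heappush(heap, (d + w, v))
    return dist


def sampled_cost(T, adj_src, adj_tgt, M, rng):
    """Sample M anchor pairs from T and assemble the N x K sampled cost matrix."""
    N, K = T.shape
    flat = rng.choice(N * K, size=M, p=(T / T.sum()).ravel())
    ii, jj = np.unravel_index(flat, (N, K))
    # One Dijkstra run per *unique* anchor node; repeated anchors reuse the result.
    rows_src = {i: dijkstra(adj_src, i) for i in set(ii.tolist())}
    rows_tgt = {j: dijkstra(adj_tgt, j) for j in set(jj.tolist())}
    D_left = np.stack([rows_src[i] for i in ii.tolist()], axis=1)  # N x M
    D_tgt = np.stack([rows_tgt[j] for j in jj.tolist()], axis=1)   # K x M
    Lam = ((D_left ** 2).mean(axis=1)[:, None]
           - (2.0 / M) * D_left @ D_tgt.T
           + (D_tgt ** 2).mean(axis=1)[None, :])
    return Lam, D_left, D_tgt


# Usage: match a ring of 12 points to itself with the identity plan.
n = 12
theta = 2 * np.pi * np.arange(n) / n
circle = np.stack([np.cos(theta), np.sin(theta)], axis=1)
adj = knn_graph(circle, k=2)   # each point linked to its two ring neighbours
T0 = np.eye(n) / n
Lam, D_left, D_tgt = sampled_cost(T0, adj, adj, M=5, rng=np.random.default_rng(0))
```

For two copies of the same ring matched by the identity plan, the diagonal of \(\Lambda\) is zero, as expected.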
Complexity
| Component | Standard GW | TorchGW |
|---|---|---|
| Cost matrix per iteration | \(O(NK(N+K))\) | \(O(NKM)\) |
| Dijkstra per iteration | \(O((N+K)^2 \log(N+K))\) | \(O(M(N+K) \log(N+K))\) |
| Sinkhorn per iteration | \(O(NK)\) | \(O(NK)\) (same) |
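With \(M \ll \min(N, K)\), TorchGW cuts the shortest-path work per iteration from \(N+K\) Dijkstra runs to at most \(2M\), so it grows linearly rather than quadratically in \(N+K\) for fixed \(M\), and it shrinks cost-matrix assembly from \(O(NK(N+K))\) to \(O(NKM)\). The Sinkhorn cost is unchanged.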
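For a rough sense of scale, take hypothetical sizes \(N = K = 10^4\) and \(M = 100\), and ignore constant factors. Both sampled terms shrink by the same factor \((N+K)/M\):
\[\frac{NK(N+K)}{NKM} = \frac{(N+K)^2 \log(N+K)}{M(N+K)\log(N+K)} = \frac{N+K}{M} = \frac{2 \times 10^4}{100} = 200.\]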
Log-Domain Sinkhorn
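TorchGW uses a pure-PyTorch log-domain Sinkhorn implementation, which stays numerically stable at small regularization \(\varepsilon\), where \(e^{-C/\varepsilon}\) would underflow:
\[\log u \leftarrow \log p - \operatorname{logsumexp}_j\bigl(\log K_{ij} + \log v_j\bigr), \qquad \log v \leftarrow \tau \Bigl(\log q - \operatorname{logsumexp}_i\bigl(\log K_{ij} + \log u_i\bigr)\Bigr)\]
where \(\log K = -C / \varepsilon\), \(T = \operatorname{diag}(u)\, K \operatorname{diag}(v)\), and \(\tau = 1\) for balanced GW.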
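The updates can be sketched in NumPy as follows. This is an illustration under our own assumptions (the target half-step is taken first and the source half-step last, so \(T\mathbf{1} = p\) holds exactly after every iteration), not TorchGW's implementation; `log_sinkhorn` and its arguments are hypothetical. The \(\varepsilon = 10^{-3}\) call shows the point of the log domain: for costs near 1, \(e^{-C/\varepsilon}\) underflows to zero in double precision, while the log-sum-exp updates stay finite.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def log_sinkhorn(C, p, q, eps, tau=1.0, n_iters=500):
    """Log-domain Sinkhorn; tau < 1 relaxes only the target (column) marginal."""
    logK = -C / eps
    log_p, log_q = np.log(p), np.log(q)
    log_u, log_v = np.zeros_like(log_p), np.zeros_like(log_q)
    for _ in range(n_iters):
        # Target half-step, damped by tau (tau = 1 enforces T^T 1 = q).
        log_v = tau * (log_q - logsumexp(logK + log_u[:, None], axis=0))
        # Source half-step: enforces T 1 = p exactly after every iteration.
        log_u = log_p - logsumexp(logK + log_v[None, :], axis=1)
    return np.exp(log_u[:, None] + logK + log_v[None, :])


rng = np.random.default_rng(0)
C = rng.uniform(size=(6, 5))
p, q = np.full(6, 1 / 6), np.full(5, 1 / 5)

T_bal = log_sinkhorn(C, p, q, eps=0.5)              # balanced (tau = 1)
T_semi = log_sinkhorn(C, p, q, eps=0.5, tau=0.5)    # target marginal relaxed
T_sharp = log_sinkhorn(C, p, q, eps=1e-3)           # tiny eps, still finite
```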
Semi-Relaxed GW
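Setting \(\tau = \rho / (\rho + \varepsilon)\) with \(\tau < 1\) replaces the hard target constraint \(T^\top \mathbf{1} = q\) with a KL divergence penalty \(\rho\, \mathrm{KL}(T^\top \mathbf{1} \,\|\, q)\), while the source constraint \(T\mathbf{1} = p\) remains exact: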
\(\rho \to \infty\): \(\tau \to 1\), recovers balanced GW
\(\rho \to 0\): \(\tau \to 0\), target marginal is completely free
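This is useful when the source and target have different compositions, for example when some structure present in one dataset is missing or under-represented in the other, so that forcing the plan to match \(q\) exactly would push mass onto poorly matched points.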
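To see the relaxation at work, the following NumPy sketch (scaling form, hypothetical function name, not TorchGW's API) solves a 2×2 problem in which target point 1 is expensive for both sources. With a large \(\rho\) the plan still sends half the mass to it, as the balanced constraint demands; with a small \(\rho\) almost no mass goes there, while the source marginal stays exact.

```python
import numpy as np


def semi_relaxed_sinkhorn(C, p, q, eps, rho, n_iters=2000):
    """Scaling-form Sinkhorn with a KL-relaxed target marginal, tau = rho / (rho + eps)."""
    tau = rho / (rho + eps)
    K = np.exp(-C / eps)
    u, v = np.ones_like(p), np.ones_like(q)
    for _ in range(n_iters):
        v = (q / (K.T @ u)) ** tau   # relaxed: only pulled toward q
        u = p / (K @ v)              # hard constraint: T 1 = p
    return u[:, None] * K * v[None, :]


# Target point 1 costs 1.0 for both sources; target point 0 is cheap.
C = np.array([[0.0, 1.0],
              [0.1, 1.0]])
p = q = np.array([0.5, 0.5])

for rho in (1e6, 1e-2):
    T = semi_relaxed_sinkhorn(C, p, q, eps=0.1, rho=rho)
    print(f"rho={rho:g}: column marginal {T.sum(axis=0)}")
```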
Differentiable Sinkhorn
When differentiable=True, TorchGW computes exact gradients through the
Sinkhorn solver, enabling end-to-end learning where the transport plan is a
differentiable function of the cost matrix.
The problem
The entropic OT solution is \(T^*_{ij} = \exp\bigl((f_i + g_j - C_{ij})/\varepsilon\bigr)\), where \(f, g\) are the Sinkhorn dual potentials (Kantorovich potentials). To backpropagate through \(T^*(C)\), we need \(\partial T^* / \partial C\).
A naive approach — freeze \(f, g\) and differentiate the exponential directly — gives \(\partial T^*_{ij}/\partial C_{ij} \approx -T^*_{ij}/\varepsilon\). This frozen-potentials approximation ignores that \(f, g\) themselves depend on \(C\) through the Sinkhorn iterations. In practice this produces gradients with cosine similarity as low as 0.07 against the true gradient, especially at small \(\varepsilon\).
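A small NumPy check illustrates the failure mode (our own sketch; `sinkhorn` is a hypothetical helper). Take \(\mathcal{L}(T^*) = \sum_{ij} T^*_{ij}\), i.e. \(G = \mathbf{1}\mathbf{1}^\top\). The marginal constraints pin the total mass to 1, so the true gradient with respect to \(C\) is zero, yet the frozen-potentials formula \(-G \odot T^*/\varepsilon\) is not.

```python
import numpy as np


def sinkhorn(C, p, q, eps, n_iters=2000):
    """Plain scaling-form Sinkhorn returning the converged plan T*."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(p), np.ones_like(q)
    for _ in range(n_iters):
        u = p / (K @ v)
        v = q / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
C = rng.uniform(size=(4, 3))
p, q, eps = np.full(4, 1 / 4), np.full(3, 1 / 3), 0.5


def loss(C):
    return sinkhorn(C, p, q, eps).sum()   # always 1: total mass is fixed


T = sinkhorn(C, p, q, eps)
frozen_grad = -np.ones_like(C) * T / eps  # G = 1 everywhere, potentials frozen

# Central finite differences give the true gradient.
h = 1e-5
true_grad = np.zeros_like(C)
for idx in np.ndindex(*C.shape):
    E = np.zeros_like(C)
    E[idx] = h
    true_grad[idx] = (loss(C + E) - loss(C - E)) / (2 * h)
```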
Implicit differentiation (default: grad_mode="implicit")
Instead of differentiating through the iterative Sinkhorn algorithm, we use the implicit function theorem (IFT) at the converged fixed point.
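Step 1: Fixed-point conditions. Write \(u = e^{f/\varepsilon}\) and \(v = e^{g/\varepsilon}\), so that \(T^* = \operatorname{diag}(u)\,K\,\operatorname{diag}(v)\). At convergence, the Sinkhorn potentials satisfy \(F(\log u, \log v) = 0\), where
\[F_u = \log u - \log a + \log(K v), \qquad F_v = \log v - \log b + \log(K^\top u),\]
\(K_{ij} = e^{-C_{ij}/\varepsilon}\), and \(a = T^*\mathbf{1}\), \(b = T^{*\top}\mathbf{1}\) are the source and target marginals (\(p\) and \(q\) in the balanced case).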
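Step 2: Jacobian of the fixed-point map. The Jacobian \(J = \partial F / \partial (\log u, \log v)\) has a clean block structure:
\[J = \begin{pmatrix} I_N & P \\ R & I_K \end{pmatrix}\]
where \(P_{ij} = T^*_{ij}/a_i\) and \(R_{ji} = T^*_{ij}/b_j\) are the row-normalized and column-normalized transport plans (the softmax outputs of the two Sinkhorn half-steps). Both \(P\) and \(R\) are row-stochastic, and \(PR\) is similar to a symmetric positive semi-definite matrix, so its eigenvalues \(\mu\) lie in \([0, 1]\). Every eigenvalue \(\lambda\) of \(J\) satisfies \((\lambda - 1)^2 = \mu\) for some such \(\mu\), hence lies in \([0, 2]\): a bounded, well-behaved spectrum whose only zero eigenvalue comes from the potential-shift mode handled in Step 4.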
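The block structure and the spectral claim can be checked numerically. The NumPy sketch below (hypothetical helper names, not TorchGW code) builds \(J = \begin{pmatrix} I_N & P \\ R & I_K \end{pmatrix}\) from \(P\) and \(R\), compares it with a finite-difference Jacobian of the fixed-point residual (written out in `fixed_point_residual`), checks that the eigenvalues are real and in \([0, 2]\), and confirms that the potential-shift direction \((\mathbf{1}_N, -\mathbf{1}_K)\) lies in the null space of \(J\).

```python
import numpy as np


def sinkhorn_scalings(C, a, b, eps, n_iters=2000):
    """Return converged scalings u, v and the kernel exp(-C / eps)."""
    kernel = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u, v, kernel


def fixed_point_residual(x, kernel, a, b):
    """F(log u, log v) stacked into one vector; zero at a Sinkhorn fixed point."""
    N = len(a)
    log_u, log_v = x[:N], x[N:]
    F_u = log_u - np.log(a) + np.log(kernel @ np.exp(log_v))
    F_v = log_v - np.log(b) + np.log(kernel.T @ np.exp(log_u))
    return np.concatenate([F_u, F_v])


rng = np.random.default_rng(1)
N, K, eps = 5, 4, 0.5
C = rng.uniform(size=(N, K))
a, b = np.full(N, 1 / N), np.full(K, 1 / K)
u, v, kernel = sinkhorn_scalings(C, a, b, eps)
T = u[:, None] * kernel * v[None, :]

P = T / T.sum(axis=1, keepdims=True)        # N x K, rows sum to 1
R = (T / T.sum(axis=0, keepdims=True)).T    # K x N, rows sum to 1
J = np.block([[np.eye(N), P], [R, np.eye(K)]])

# Central finite-difference Jacobian of F at the fixed point, column by column.
x0 = np.concatenate([np.log(u), np.log(v)])
h = 1e-6
J_fd = np.stack([(fixed_point_residual(x0 + h * e, kernel, a, b)
                  - fixed_point_residual(x0 - h * e, kernel, a, b)) / (2 * h)
                 for e in np.eye(N + K)], axis=1)
eigs = np.linalg.eigvals(J)
```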
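Step 3: Adjoint equation. By the IFT, the vector-Jacobian product (VJP) for an upstream loss \(\mathcal{L}(T^*)\) requires solving the adjoint system
\[J^\top \begin{pmatrix} \lambda_u \\ \lambda_v \end{pmatrix} = \begin{pmatrix} r_u \\ r_v \end{pmatrix}, \quad\text{i.e.,}\quad \lambda_u + R^\top \lambda_v = r_u, \qquad P^\top \lambda_u + \lambda_v = r_v,\]
where \(r_u = (G \odot T^*)\,\mathbf{1}\), \(r_v = (G \odot T^*)^\top \mathbf{1}\), and \(G = \partial\mathcal{L}/\partial T^*\) is the upstream gradient. Because \(T^*_{ij} = u_i K_{ij} v_j\), \(r_u\) and \(r_v\) are exactly \(\partial\mathcal{L}/\partial \log u\) and \(\partial\mathcal{L}/\partial \log v\).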
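Step 4: Schur complement solve. Substituting \(\lambda_u = a \odot \mu_u\) and \(\lambda_v = b \odot \mu_v\) (with \(\oslash\) denoting elementwise division below) turns the adjoint system into \(\mu_u + P\mu_v = r_u \oslash a\) and \(R\mu_u + \mu_v = r_v \oslash b\). Eliminating \(\mu_u\) then gives a \(K \times K\) system:
\[S \mu_v = r_v \oslash b - R\,(r_u \oslash a), \qquad S = I_K - R P, \qquad \mu_u = r_u \oslash a - P \mu_v.\]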
\(S\) has a rank-1 null space (eigenvector \(\mathbf{1}_K\),
from the constant ambiguity in Sinkhorn potentials: \(f + c, g - c\)
yield the same \(T^*\)). This null mode cancels in the final gradient,
so we remove it by adding \(\mathbf{1}\mathbf{1}^\top\!/K\) to \(S\),
replacing the zero eigenvalue with 1. The solve is then a standard
torch.linalg.solve call.
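The deflation trick can be checked in a few lines of NumPy (our sketch, with a random stand-in for the upstream gradient \(G\); NumPy's `np.linalg.solve` plays the role of `torch.linalg.solve`). It forms \(S = I_K - RP\), which annihilates \(\mathbf{1}_K\), solves the deflated system \((S + \mathbf{1}\mathbf{1}^\top/K)\,\mu_v = r_v \oslash b - R(r_u \oslash a)\), and confirms that the result also solves the original singular system.

```python
import numpy as np

rng = np.random.default_rng(2)
N, K, eps = 6, 4, 0.5
C = rng.uniform(size=(N, K))
a0, b0 = np.full(N, 1 / N), np.full(K, 1 / K)

# Converged plan from a plain scaling-form Sinkhorn loop.
kernel = np.exp(-C / eps)
u, v = np.ones(N), np.ones(K)
for _ in range(2000):
    u = a0 / (kernel @ v)
    v = b0 / (kernel.T @ u)
T = u[:, None] * kernel * v[None, :]
a, b = T.sum(axis=1), T.sum(axis=0)         # marginals of the computed plan

P = T / a[:, None]                          # N x K, row-stochastic
R = (T / b[None, :]).T                      # K x N, row-stochastic
G = rng.normal(size=(N, K))                 # stand-in upstream gradient
r_u, r_v = (G * T).sum(axis=1), (G * T).sum(axis=0)

S = np.eye(K) - R @ P                       # singular: S @ 1 = 0
rhs = r_v / b - R @ (r_u / a)
mu_v = np.linalg.solve(S + np.ones((K, K)) / K, rhs)   # deflated, nonsingular
mu_u = r_u / a - P @ mu_v
```

This works because \(b^\top S = 0\) and \(b^\top\) also annihilates the right-hand side, so multiplying the deflated system by \(b^\top\) forces \(\mathbf{1}^\top \mu_v = 0\); the deflation simply picks one representative of the family \((\mu_u + c, \mu_v - c)\).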
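Step 5: Final VJP. With \(\oslash\) denoting elementwise division, the gradient with respect to the cost matrix is
\[\frac{\partial \mathcal{L}}{\partial C} = -\frac{1}{\varepsilon}\, T^* \odot \Bigl(G - (\lambda_u \oslash a)\,\mathbf{1}_K^\top - \mathbf{1}_N\,(\lambda_v \oslash b)^\top\Bigr).\]
The frozen-potentials approximation corresponds to dropping the two \(\lambda\) terms. Shifting \(\lambda_u \oslash a\) by \(+c\) and \(\lambda_v \oslash b\) by \(-c\) leaves this expression unchanged, which is why the null mode of Step 4 cancels.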
Complexity: \(O(NK^2 + K^3)\) for the Schur complement construction and solve. Memory: \(O(K^2)\) for the Schur complement matrix plus \(O(NK)\) for \(T^*\). No Sinkhorn iterations are stored.
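Putting the steps together, here is an end-to-end NumPy sketch of the implicit VJP, as a stand-in for the PyTorch version (`implicit_grad` is a hypothetical name, and \(\mu = \lambda \oslash\) marginal is the rescaled adjoint). It is validated against central finite differences of \(\mathcal{L}(C) = \sum_{ij} G_{ij} T^*_{ij}(C)\) and also prints the cosine similarity between the frozen-potentials gradient and the implicit one.

```python
import numpy as np


def sinkhorn(C, a, b, eps, n_iters=2000):
    """Plain scaling-form Sinkhorn returning the converged plan T*."""
    kernel = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def implicit_grad(T, G, eps):
    """dL/dC via the implicit function theorem at the Sinkhorn fixed point."""
    N, K = T.shape
    a, b = T.sum(axis=1), T.sum(axis=0)
    P = T / a[:, None]                        # N x K, row-stochastic
    R = (T / b[None, :]).T                    # K x N, row-stochastic
    r_u, r_v = (G * T).sum(axis=1), (G * T).sum(axis=0)
    # Schur complement: O(N K^2) to build, O(K^3) to solve; deflated by 11^T / K.
    S = np.eye(K) - R @ P + np.ones((K, K)) / K
    mu_v = np.linalg.solve(S, r_v / b - R @ (r_u / a))
    mu_u = r_u / a - P @ mu_v
    return -(T / eps) * (G - mu_u[:, None] - mu_v[None, :])


rng = np.random.default_rng(3)
N, K, eps = 4, 3, 0.5
C = rng.uniform(size=(N, K))
a, b = np.full(N, 1 / N), np.full(K, 1 / K)
G = rng.normal(size=(N, K))                   # dL/dT for L(T) = sum(G * T)

T = sinkhorn(C, a, b, eps)
grad_ift = implicit_grad(T, G, eps)
grad_frozen = -G * T / eps                    # frozen-potentials approximation

# Reference gradient by central finite differences through the full solver.
h = 1e-5
grad_fd = np.zeros_like(C)
for idx in np.ndindex(*C.shape):
    E = np.zeros_like(C)
    E[idx] = h
    grad_fd[idx] = (np.sum(G * sinkhorn(C + E, a, b, eps))
                    - np.sum(G * sinkhorn(C - E, a, b, eps))) / (2 * h)

cos = np.sum(grad_frozen * grad_ift) / (np.linalg.norm(grad_frozen) * np.linalg.norm(grad_ift))
print(f"cosine(frozen, implicit) = {cos:.3f}")
```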
Unrolled autograd (grad_mode="unrolled")
An alternative is to simply run the Sinkhorn loop under torch.enable_grad(),
letting PyTorch’s autograd record and differentiate through every iteration.
This gives exact gradients (matching the implicit mode up to floating-point
precision) but stores all intermediate states:
Memory: \(O(NK \times \text{sinkhorn\_iters})\)
Speed: ~1.5–2x slower than implicit (extra graph bookkeeping)
When to use: debugging, or when \(\varepsilon\) is extremely small (< 0.001) and the transport plan has severe floating-point underflow that limits the implicit mode’s accuracy.
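For a rough sense of scale, assume hypothetical sizes \(N = K = 10^4\), 100 Sinkhorn iterations, float32 storage, and only one stored \(N \times K\) tensor per iteration (autograd may keep more, so the unrolled figure is a lower bound):
\[\underbrace{NK \times \text{iters} \times 4\,\text{B}}_{\text{unrolled}} = 10^8 \times 100 \times 4\,\text{B} = 40\,\text{GB}, \qquad \underbrace{(NK + K^2) \times 4\,\text{B}}_{\text{implicit}} = 2 \times 10^8 \times 4\,\text{B} = 0.8\,\text{GB}.\]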
Summary
| grad_mode | Gradient | Memory | Notes |
|---|---|---|---|
| implicit | Exact (IFT) | \(O(NK + K^2)\) | Best default; Schur solve |
| unrolled | Exact (autograd) | \(O(NK \times \text{iters})\) | Fallback for extreme \(\varepsilon\) |
Regularization Schedule
The entropic regularization \(\varepsilon\) is decayed exponentially during optimization, from its initial value \(\varepsilon_0\) toward \(\varepsilon_{\min} = \min(5 \times 10^{-4}, \varepsilon_0)\). A large initial regularization helps exploration; a small final regularization sharpens the transport plan.
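As an illustration only, one common exponential form interpolates geometrically between \(\varepsilon_0\) and \(\varepsilon_{\min}\) over a fixed number of steps; TorchGW's exact decay rule may differ, and `eps_schedule` and `n_steps` are hypothetical names, not TorchGW parameters. When \(\varepsilon_0 \leq 5 \times 10^{-4}\), the floor equals \(\varepsilon_0\) and the schedule is flat.

```python
import numpy as np


def eps_schedule(eps0, n_steps, eps_floor=5e-4):
    """Assumed geometric decay from eps0 to eps_min = min(eps_floor, eps0)."""
    eps_min = min(eps_floor, eps0)
    t = np.arange(n_steps) / max(n_steps - 1, 1)   # 0 -> 1 across the run
    return eps0 * (eps_min / eps0) ** t


schedule = eps_schedule(0.1, n_steps=50)
print(schedule[0], schedule[-1])
```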