# Source code for torchgw._solver

import warnings

import numpy as np
import torch

from torchgw._graph import build_knn_graph
from torchgw._sampling import sample_pairs_gpu
from torchgw._utils import get_device


# ── Sinkhorn core (shared by both no_grad and differentiable paths) ──────

def _sinkhorn_iterations(
    log_K: torch.Tensor,
    log_a: torch.Tensor,
    log_b: torch.Tensor,
    log_u: torch.Tensor,
    log_v: torch.Tensor,
    is_balanced: bool,
    tau: float,
    n_iter: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run exactly ``n_iter`` log-domain Sinkhorn sweeps (compilable).

    No convergence test is performed, so the loop body contains tensor ops
    only. Each sweep refreshes ``log_u`` from the current ``log_v`` and then
    computes a raw ``log_v`` from the fresh ``log_u``:

    * ``is_balanced=True``  -> the raw value is taken as-is;
    * ``is_balanced=False`` -> the raw value is blended with the previous
      ``log_v`` as ``tau * raw + (1 - tau) * previous``.

    ``log_v`` is broadcast along the rows of ``log_K`` and ``log_u`` along its
    columns. With ``n_iter == 0`` the incoming potentials are returned
    untouched.

    Returns:
        The pair ``(log_u, log_v)`` after the last sweep.
    """
    for _ in range(n_iter):
        # u-side: match the row marginal given the current column potential.
        row_lse = torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        log_u = log_a - row_lse

        # v-side: raw update first, then optional blending with the old value.
        col_lse = torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
        proposed_v = log_b - col_lse
        if not is_balanced:
            proposed_v = tau * proposed_v + (1 - tau) * log_v
        log_v = proposed_v
    return log_u, log_v


# Cache slot for the torch.compile'd Sinkhorn kernel (kernel fusion). It stays
# None until first requested so that importing this module never compiles.
_sinkhorn_iterations_compiled = None


def _get_compiled_sinkhorn():
    """Return the compiled ``_sinkhorn_iterations``, building it on first use.

    The compiled callable is stored in the module-level
    ``_sinkhorn_iterations_compiled`` slot and reused by later calls.
    """
    global _sinkhorn_iterations_compiled
    cached = _sinkhorn_iterations_compiled
    if cached is not None:
        return cached
    cached = torch.compile(
        _sinkhorn_iterations, mode="reduce-overhead", dynamic=False,
    )
    _sinkhorn_iterations_compiled = cached
    return cached


def _sinkhorn_loop(
    log_K: torch.Tensor, log_a: torch.Tensor, log_b: torch.Tensor,
    tau_a: float, tau_b: float, max_iter: int, tol: float, check_every: int,
    a: torch.Tensor, verbose: bool = False,
    log_u_init: torch.Tensor | None = None,
    log_v_init: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run Sinkhorn, preferring the Triton kernel when it can honour the request.

    Dispatch order: Triton (CUDA only, balanced or v-side single-tau
    semi-relaxed) → pure PyTorch fallback (handles both sides via
    ``tau_a``, ``tau_b``).

    The Triton entry point is called with a single damping factor, which is
    ``tau_b`` (the legacy single-tau lived on the v side). It is therefore
    only used when ``tau_a == 1.0``; any configuration that damps the u side
    goes to the PyTorch implementation so that ``tau_a`` is never silently
    dropped and CUDA and CPU runs follow the same update rule.

    Args:
        log_K: Log-kernel matrix.
        log_a, log_b: Log marginals for the u side and v side.
        tau_a, tau_b: Per-side damping factors; 1.0 means strictly balanced.
        max_iter, tol, check_every: Iteration budget and convergence-check
            controls, forwarded unchanged to the selected backend.
        a: Linear-domain u-side marginal, forwarded for the convergence check.
        verbose: Forwarded to the backend.
        log_u_init, log_v_init: Optional warm-start potentials.

    Returns:
        The pair ``(log_u, log_v)`` produced by the selected backend.
    """
    triton_eligible = log_K.is_cuda and tau_a == 1.0
    if triton_eligible:
        try:
            from torchgw._triton_sinkhorn import triton_sinkhorn_loop
            return triton_sinkhorn_loop(log_K, log_a, log_b, tau_b, max_iter,
                                        tol, check_every, a, verbose,
                                        log_u_init=log_u_init, log_v_init=log_v_init)
        except (ImportError, RuntimeError):
            # Triton is a best-effort accelerator: if it is missing or the
            # kernel fails, fall through to the PyTorch implementation.
            pass
    return _sinkhorn_loop_pytorch(log_K, log_a, log_b, tau_a, tau_b,
                                   max_iter, tol, check_every, a, verbose,
                                   log_u_init=log_u_init, log_v_init=log_v_init)


def _sinkhorn_loop_pytorch(
    log_K: torch.Tensor,
    log_a: torch.Tensor,
    log_b: torch.Tensor,
    tau_a: float,
    tau_b: float,
    max_iter: int,
    tol: float,
    check_every: int,
    a: torch.Tensor,
    verbose: bool = False,
    log_u_init: torch.Tensor | None = None,
    log_v_init: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch Sinkhorn fallback. tau_a, tau_b control KL damping per side
    (1.0 = strict balanced; <1 = unbalanced KL relaxation)."""
    log_u = log_u_init if log_u_init is not None else torch.zeros_like(log_a)
    log_v = log_v_init if log_v_init is not None else torch.zeros_like(log_b)
    is_balanced_a = (tau_a == 1.0)
    is_balanced_b = (tau_b == 1.0)

    for it in range(max_iter):
        log_u_raw = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        log_u = log_u_raw if is_balanced_a else tau_a * log_u_raw
        log_v_raw = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
        log_v = log_v_raw if is_balanced_b else tau_b * log_v_raw

        if tol > 0 and (it + 1) % check_every == 0:
            log_marginal = log_u + torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
            marginal_err = torch.abs(torch.exp(log_marginal) - a).max().item()
            if verbose:
                print(f"    sinkhorn {it+1:>4}/{max_iter} | marginal_err: {marginal_err:.4e}")
            if marginal_err < tol:
                if verbose:
                    print(f"    sinkhorn converged at {it+1} (err={marginal_err:.4e})")
                break
    return log_u, log_v


def _sinkhorn_torch(
    a: torch.Tensor,
    b: torch.Tensor,
    C: torch.Tensor,
    reg: float,
    max_iter: int = 100,
    tol: float = 5e-4,
    check_every: int = 10,
    semi_relaxed: bool = False,
    rho: float | None = None,  # legacy single-side semi-relaxed alias
    rho_a: float = 1.0,
    rho_b: float = 1.0,
    _inplace_C: bool = False,
    verbose: bool = False,
    log_u_init: torch.Tensor | None = None,
    log_v_init: torch.Tensor | None = None,
) -> torch.Tensor:
    """Log-domain Sinkhorn supporting balanced, single-side semi-relaxed,
    and fully-unbalanced via (rho_a, rho_b).

    Backward-compatible: if `rho` is provided (and `rho_a`/`rho_b` are at
    their defaults), legacy single-side semi-relaxed semantics is used
    (damping on the v/target side only — original behavior pre-PR). On that
    path the u side is kept strictly balanced (``tau_a = 1.0``) rather than
    being derived from ``rho_a``, because ``rho_a / (rho_a + reg)`` is never
    1.0 for a finite ``rho_a`` and would damp the u side as well.

    Args:
        a, b: Marginals; clamped to at least 1e-30 before taking the log.
        C: Cost matrix. Overwritten with ``-C / reg`` when ``_inplace_C`` is
            True; otherwise left untouched.
        reg: Entropic regularisation strength.
        max_iter, tol, check_every: Forwarded to the Sinkhorn loop.
        semi_relaxed: When False both sides are strictly balanced and all
            ``rho*`` arguments are ignored (apart from the conflict check).
        rho: Legacy v-side relaxation strength; mutually exclusive with
            non-default ``rho_a`` / ``rho_b``.
        rho_a, rho_b: Per-side relaxation strengths; each is mapped to
            ``tau = rho / (rho + reg)``.
        _inplace_C: Reuse ``C``'s storage for the log-kernel.
        verbose: Forwarded to the Sinkhorn loop.
        log_u_init, log_v_init: Optional warm-start potentials.

    Returns:
        The transport plan ``T``. The detached final potentials are attached
        to it as ``T._log_u`` and ``T._log_v`` for warm-starting a later call.

    Raises:
        ValueError: If ``rho`` is given together with a non-default
            ``rho_a`` or ``rho_b``.
    """
    legacy_single_side = rho is not None
    if legacy_single_side:
        if rho_a != 1.0 or rho_b != 1.0:
            raise ValueError("pass either rho (legacy) OR rho_a/rho_b, not both")
        rho_b = rho  # legacy: damping only on v-side (target)

    log_K = C.neg_().div_(reg) if _inplace_C else -C / reg
    log_a = torch.log(a.clamp(min=1e-30))
    log_b = torch.log(b.clamp(min=1e-30))

    if semi_relaxed:
        # Legacy callers expect a strictly balanced u side; only the v side
        # is relaxed. New-style callers get an independent tau per side.
        tau_a = 1.0 if legacy_single_side else rho_a / (rho_a + reg)
        tau_b = rho_b / (rho_b + reg)
    else:
        tau_a = tau_b = 1.0

    log_u, log_v = _sinkhorn_loop(log_K, log_a, log_b, tau_a, tau_b,
                                  max_iter, tol, check_every, a,
                                  verbose=verbose,
                                  log_u_init=log_u_init, log_v_init=log_v_init)

    # Fused T materialization (Triton on CUDA, PyTorch fallback)
    if log_K.is_cuda:
        try:
            from torchgw._triton_sinkhorn import triton_materialize_T
            T = triton_materialize_T(log_u, log_K, log_v)
        except (ImportError, RuntimeError):
            T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0))
    else:
        T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0))

    # Stash potentials for warm-starting the next call
    T._log_u = log_u.detach()  # type: ignore[attr-defined]
    T._log_v = log_v.detach()  # type: ignore[attr-defined]
    return T


def _adjoint_sinkhorn_vjp( T: torch.Tensor, a: torch.Tensor, b: torch.Tensor, reg: float, grad_T: torch.Tensor, ) -> torch.Tensor: """Compute dL/dC via implicit differentiation at the Sinkhorn fixed point. The adjoint system (from IFT on Sinkhorn fixed-point conditions) is: J^T · [λ_u, λ_v] = [r_u, r_v] where J = [[I, P], [R, I]], P_{ij} = T_{ij}/a_i, R_{ji} = T_{ij}/b_j, r_u = (G ⊙ T)·1, r_v = (G ⊙ T)^T·1, G = grad_T. Solved via Schur complement on J^T (well-conditioned, eigenvalues in [0, 2]). The system has a rank-1 null space from the Sinkhorn potential constant ambiguity, removed by a rank-1 correction (11^T/K) that preserves the gradient-relevant components. Final VJP: dL/dC_{kl} = (T_{kl}/ε) · (-G_{kl} + λ_u_k/a_k + λ_v_l/b_l) """ N, K = T.shape G_T = grad_T * T # G ⊙ T, shape (N, K) r_u = G_T.sum(dim=1) # (N,) r_v = G_T.sum(dim=0) # (K,) # Jacobian blocks (row-stochastic matrices) P = T / a.unsqueeze(1) # (N, K), P_{ij} = T_{ij}/a_i R_T = T / b.unsqueeze(0) # (N, K), R^T_{ij} = T_{ij}/b_j # Schur complement on J^T: eliminate λ_u = r_u - R^T λ_v # → (I_K - P^T R^T) λ_v = r_v - P^T r_u S = torch.eye(K, dtype=T.dtype, device=T.device) - P.T @ R_T # (K, K) # S has a rank-1 null space (eigvec ∝ 1_K) from the potential constant # ambiguity. Adding 11^T/K replaces the zero eigenvalue with 1, making # S nonsingular. Since the RHS is orthogonal to 1_K (sum(r_u) = sum(r_v) # for any valid upstream gradient), this doesn't affect the solution. 
S += torch.ones(K, K, dtype=T.dtype, device=T.device) / K rhs_v = r_v - P.T @ r_u # (K,) lambda_v = torch.linalg.solve(S, rhs_v) # (K,) lambda_u = r_u - R_T @ lambda_v # (N,) grad_C = (T / reg) * (-grad_T + (lambda_u / a).unsqueeze(1) + (lambda_v / b).unsqueeze(0)) return grad_C class _SinkhornImplicit(torch.autograd.Function): """Differentiable Sinkhorn with exact gradient via implicit differentiation.""" @staticmethod def forward(ctx, C, a, b, reg, max_iter, tol, check_every): log_K = -C / reg log_a = torch.log(a.clamp(min=1e-30)) log_b = torch.log(b.clamp(min=1e-30)) log_u, log_v = _sinkhorn_loop(log_K, log_a, log_b, 1.0, 1.0, max_iter, tol, check_every, a) T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) ctx.save_for_backward(T, a, b) ctx.reg = reg return T @staticmethod def backward(ctx, grad_T): T, a, b = ctx.saved_tensors grad_C = _adjoint_sinkhorn_vjp(T, a, b, ctx.reg, grad_T) return grad_C, None, None, None, None, None, None class _SinkhornApproximate(torch.autograd.Function): """Differentiable Sinkhorn with frozen-potentials approximation. Backward: dT/dC ≈ -T/ε (treats potentials as constants). Fast but inexact — use _SinkhornImplicit for exact gradients. 
""" @staticmethod def forward(ctx, C, a, b, reg, max_iter, tol, check_every, semi_relaxed, rho_a, rho_b): if semi_relaxed: tau_a = rho_a / (rho_a + reg) tau_b = rho_b / (rho_b + reg) else: tau_a = tau_b = 1.0 log_K = -C / reg log_a = torch.log(a.clamp(min=1e-30)) log_b = torch.log(b.clamp(min=1e-30)) log_u, log_v = _sinkhorn_loop(log_K, log_a, log_b, tau_a, tau_b, max_iter, tol, check_every, a, False) T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) ctx.save_for_backward(T, a, b) ctx.reg = reg return T @staticmethod def backward(ctx, grad_T): (T, a, b) = ctx.saved_tensors grad_C = -grad_T * T / ctx.reg return grad_C, None, None, None, None, None, None, None, None, None def _sinkhorn_unrolled( C, a, b, reg, max_iter=100, tol=5e-4, check_every=10, semi_relaxed=False, rho: float | None = None, rho_a: float = 1.0, rho_b: float = 1.0, grad_mode="autograd", verbose=False, ): if rho is not None: if rho_a != 1.0 or rho_b != 1.0: raise ValueError("pass either rho (legacy) OR rho_a/rho_b, not both") rho_a, rho_b = 1.0, rho if semi_relaxed: tau_a = rho_a / (rho_a + reg) tau_b = rho_b / (rho_b + reg) else: tau_a = tau_b = 1.0 log_K = -C / reg log_a = torch.log(a.clamp(min=1e-30)) log_b = torch.log(b.clamp(min=1e-30)) log_u = torch.zeros_like(log_a) log_v = torch.zeros_like(log_b) is_a_balanced = (tau_a == 1.0) is_b_balanced = (tau_b == 1.0) for it in range(max_iter): log_u_raw = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1) log_u = log_u_raw if is_a_balanced else tau_a * log_u_raw log_v_raw = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0) log_v = log_v_raw if is_b_balanced else tau_b * log_v_raw if tol > 0 and (it + 1) % check_every == 0: log_marginal = log_u + torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1) if torch.abs(torch.exp(log_marginal) - a).max().item() < tol: break return torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) _VALID_GRAD_MODES = {"implicit", "unrolled", "approximate"}
[docs] def _sinkhorn_differentiable( C, a, b, reg, max_iter=100, tol=5e-4, check_every=10, semi_relaxed=False, rho: float | None = None, rho_a: float = 1.0, rho_b: float = 1.0, grad_mode="autograd", verbose=False, ): if rho is not None: if rho_a != 1.0 or rho_b != 1.0: raise ValueError("pass either rho (legacy) OR rho_a/rho_b, not both") rho_a, rho_b = 1.0, rho if grad_mode not in {"autograd", "implicit", "approximate", "unrolled"}: raise ValueError(f"invalid grad_mode={grad_mode!r}; " "must be one of 'autograd', 'implicit', 'approximate', 'unrolled'") if semi_relaxed: return _sinkhorn_unrolled(C, a, b, reg, max_iter, tol, check_every, semi_relaxed, None, rho_a, rho_b, grad_mode, verbose) if grad_mode == "implicit": return _SinkhornImplicit.apply(C, a, b, reg, max_iter, tol, check_every) if grad_mode == "approximate": return _SinkhornApproximate.apply( C, a, b, reg, max_iter, tol, check_every, semi_relaxed, rho_a, rho_b, ) return _sinkhorn_unrolled(C, a, b, reg, max_iter, tol, check_every, semi_relaxed, None, rho_a, rho_b, grad_mode, verbose)
# ── Input coercion ────────────────────────────────────────────────────── def _to_tensor(x): """Convert numpy array or tensor to torch.Tensor. Pass through None.""" if x is None: return None if isinstance(x, torch.Tensor): return x return torch.as_tensor(x) # ── Shared preprocessing ──────────────────────────────────────────────── def _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ): """Coerce inputs, infer N/K, validate, build distance provider. Returns (X_source, X_target, dist_source, dist_target, C_linear_t, N, K, provider, device). """ from torchgw._distances import DijkstraProvider, PrecomputedProvider, LandmarkProvider X_source = _to_tensor(X_source) X_target = _to_tensor(X_target) p = _to_tensor(p) q = _to_tensor(q) dist_source = _to_tensor(dist_source) dist_target = _to_tensor(dist_target) C_linear_t = _to_tensor(C_linear) # Infer N, K if dist_source is not None and dist_target is not None: N, K = dist_source.shape[0], dist_target.shape[0] elif X_source is not None and X_target is not None: N, K = X_source.shape[0], X_target.shape[0] elif C_linear_t is not None: N, K = C_linear_t.shape else: raise ValueError( "Cannot infer dataset sizes. Provide (X_source, X_target), " "(dist_source, dist_target), or C_linear." 
) # Validate _VALID_MODES = {"precomputed", "dijkstra", "landmark"} if distance_mode not in _VALID_MODES: raise ValueError( f"distance_mode must be one of {_VALID_MODES}, got {distance_mode!r}" ) if fgw_alpha > 0 and C_linear_t is None: raise ValueError("fgw_alpha > 0 requires C_linear to be provided") if device is None: device = get_device() # Build distance provider if fgw_alpha >= 1.0: provider = None elif distance_mode == "precomputed": if dist_source is not None and dist_target is not None: provider = PrecomputedProvider(dist_source=dist_source, dist_target=dist_target) elif X_source is not None and X_target is not None: graphs = ( build_knn_graph(X_source.cpu().numpy(), k=k), build_knn_graph(X_target.cpu().numpy(), k=k), ) provider = PrecomputedProvider(graph_source=graphs[0], graph_target=graphs[1]) else: raise ValueError( "distance_mode='precomputed' requires (dist_source, dist_target) " "or (X_source, X_target)" ) elif distance_mode == "dijkstra": if X_source is None or X_target is None: raise ValueError("distance_mode='dijkstra' requires X_source and X_target") graph_source = build_knn_graph(X_source.cpu().numpy(), k=k) graph_target = build_knn_graph(X_target.cpu().numpy(), k=k) provider = DijkstraProvider(graph_source, graph_target) elif distance_mode == "landmark": if X_source is None or X_target is None: raise ValueError("distance_mode='landmark' requires X_source and X_target") graph_source = build_knn_graph(X_source.cpu().numpy(), k=k) graph_target = build_knn_graph(X_target.cpu().numpy(), k=k) provider = LandmarkProvider(graph_source, graph_target, n_landmarks=n_landmarks) return X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device def _gw_loop( *, N: int, K: int, provider, p_real: torch.Tensor, q_real: torch.Tensor, T_init: torch.Tensor, sinkhorn_fn, use_augmented: bool, s_shared: int | None, fgw_alpha: float, C_lin_device: torch.Tensor | None, M: int, alpha: float, max_iter: int, tol: float, epsilon: float, 
min_iter_before_converge: int, device: torch.device, verbose: bool, verbose_every: int, semi_relaxed: bool, rho_a: float, rho_b: float, differentiable: bool = False, lambda_ema_beta: float | None = None, mixed_precision: bool = False, ) -> tuple[torch.Tensor, list, int, float]: """Shared main loop for sampled_gw and sampled_lowrank_gw. Parameters ---------- sinkhorn_fn : callable For standard: takes (p_aug, q_aug, Lambda_aug, reg, **kw) -> T_aug For low-rank: takes (p_real, q_real, Lambda, reg, **kw) -> T_new use_augmented : bool If True, build augmented cost/marginals and call sinkhorn_fn on them. If False, call sinkhorn_fn directly on (p_real, q_real, Lambda). Returns ------- T_out, err_list, n_iter, gw_cost_val """ if lambda_ema_beta is not None and not (0.0 <= lambda_ema_beta <= 1.0): raise ValueError(f"lambda_ema_beta must be in [0, 1], got {lambda_ema_beta}") if M < 1: raise ValueError(f"M must be >= 1, got {M}") # Sinkhorn internal dtype: float32 when mixed_precision, else float64 sink_dtype = torch.float32 if mixed_precision else torch.float64 T_real = T_init.to(sink_dtype) # Augmented marginals (only needed for standard Sinkhorn) if use_augmented: m_frac = s_shared / max(N, K) if s_shared is not None else min(N, K) / max(N, K) slack_p = max(q_real.sum().item() - m_frac, 1e-10) slack_q = max(p_real.sum().item() - m_frac, 1e-10) p_aug = torch.cat([p_real, torch.tensor([slack_p], device=device, dtype=torch.float64)]) q_aug = torch.cat([q_real, torch.tensor([slack_q], device=device, dtype=torch.float64)]) # Regularization decay (at most 10x reduction to avoid instability) initial_reg = epsilon if epsilon > 0 else 1e-4 final_reg = max(initial_reg / 10.0, min(5e-4, initial_reg)) decay = (final_reg / initial_reg) ** (1 / max(1, max_iter)) err_list = [] gw_cost_val = 0.0 n_iter = 0 Lambda_ema = None # EMA state for cost matrix smoothing _warm_log_u: torch.Tensor | None = None # Sinkhorn warm-start potentials _warm_log_v: torch.Tensor | None = None # Cost plateau 
detection via EMA + patience. # Critical because err = ||T - T_prev|| reflects sampling noise (not # optimization progress) and may never converge to tol. _cost_ema: float | None = None _best_cost_ema = float('inf') _no_improve = 0 _patience = max(min_iter_before_converge // 2, 20) # Pre-allocate augmented cost matrix and cast marginals (reused every iteration) if use_augmented: Lambda_aug = torch.zeros(N + 1, K + 1, device=device, dtype=sink_dtype) p_aug_sink = p_aug.to(sink_dtype) q_aug_sink = q_aug.to(sink_dtype) p_sink = p_real.to(sink_dtype) q_sink = q_real.to(sink_dtype) for i in range(max_iter): current_reg = initial_reg * (decay ** i) # Sample anchor pairs (on GPU, only transfers 2*M ints back) j_left, l_target = sample_pairs_gpu(T_real.detach(), M) # Compute distances via provider if provider is not None: D_left, D_tgt = provider.get_distances(j_left, l_target, device) for D in [D_left, D_tgt]: inf_mask = torch.isinf(D) if torch.any(inf_mask): finite_vals = D[~inf_mask] fill = finite_vals.max() * 1.5 if finite_vals.numel() > 0 else 1.0 D[inf_mask] = fill mx = D.max() if mx > 0: D /= mx # Build Lambda_gw in sink_dtype (float32 when mixed_precision) # to avoid a full N*K float64 allocation. 
D_left_s = D_left if D_left.dtype == sink_dtype else D_left.to(sink_dtype) D_tgt_s = D_tgt if D_tgt.dtype == sink_dtype else D_tgt.to(sink_dtype) term_A = torch.mean(D_left_s ** 2, dim=1, keepdim=True) term_C = torch.mean(D_tgt_s ** 2, dim=1, keepdim=True).T Lambda_gw = torch.mm(D_left_s, D_tgt_s.T) Lambda_gw.mul_(-2.0 / M).add_(term_A).add_(term_C) del D_left_s, D_tgt_s # Lambda EMA: smooth cost matrix across iterations # beta=0.0 is treated as disabled (same as None) if lambda_ema_beta is not None and lambda_ema_beta > 0: if Lambda_ema is None: Lambda_ema = Lambda_gw else: Lambda_ema = (1 - lambda_ema_beta) * Lambda_ema + lambda_ema_beta * Lambda_gw Lambda_gw = Lambda_ema.clone() else: Lambda_gw = None # FGW blending if fgw_alpha >= 1.0: Lambda = C_lin_device elif fgw_alpha > 0: Lambda = (1 - fgw_alpha) * Lambda_gw + fgw_alpha * C_lin_device else: Lambda = Lambda_gw # Sinkhorn step if use_augmented: Lambda_aug[:N, :K] = Lambda if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype) penalty = 100.0 * Lambda.max().clamp(min=1.0) # stays on GPU, no sync Lambda_aug[:-1, -1] = penalty Lambda_aug[-1, :-1] = penalty Lambda_aug[-1, -1] = 0.0 verbose_sink = verbose and (n_iter + 1) % verbose_every == 0 T_aug = sinkhorn_fn(p_aug_sink, q_aug_sink, Lambda_aug, current_reg, semi_relaxed=semi_relaxed, rho_a=rho_a, rho_b=rho_b, verbose=verbose_sink, log_u_init=_warm_log_u, log_v_init=_warm_log_v, _inplace_C=True) T_new = T_aug[:-1, :-1] # Retrieve potentials for warm-starting next iteration _warm_log_u = getattr(T_aug, '_log_u', None) _warm_log_v = getattr(T_aug, '_log_v', None) else: verbose_sink = verbose and (n_iter + 1) % verbose_every == 0 Lambda_sink = Lambda if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype) T_new = sinkhorn_fn(p_sink, q_sink, Lambda_sink, current_reg, semi_relaxed=semi_relaxed, rho_a=rho_a, rho_b=rho_b, verbose=verbose_sink, log_u_init=_warm_log_u, log_v_init=_warm_log_v) _warm_log_u = getattr(T_new, '_log_u', None) _warm_log_v = getattr(T_new, 
'_log_v', None) # In-place momentum update: T_real = (1-alpha)*T_real + alpha*T_new # Avoids allocating a separate T_prev copy (saves one N*K buffer). # Compute convergence metric BEFORE the in-place update. if differentiable: # Differentiable mode: keep T_new in graph, no in-place T_prev = T_real.detach().clone() T_real = (1 - alpha) * T_prev + alpha * T_new err_tensor = torch.linalg.norm(T_real - T_prev) del T_prev else: # In-place momentum: T_real ← (1-α)T_real + αT_new # After update: T_real_new - T_real_old = α(T_new - T_real_old) # = α/(1-α) * (T_real_new - T_new) [since T_real_new - T_new = (1-α)(T_real_old - T_new)] # Compute in-place to avoid N*K temporaries: T_real.mul_(1 - alpha).add_(T_new, alpha=alpha) # Reuse T_new buffer for err: T_new ← T_real - T_new (in-place) T_new.neg_().add_(T_real) # T_new is now (T_real_new - T_new_orig) err_tensor = (alpha / (1 - alpha)) * T_new.norm() n_iter = i + 1 _check_interval = 5 # sync with CPU every N iterations # Only sync to CPU at check intervals (reduces CUDA sync overhead) if n_iter % _check_interval == 0 or i == max_iter - 1 or i >= min_iter_before_converge: # Frobenius inner product <Lambda, T>, computed via batched row # dots to avoid N*K temporary and int32 overflow (N*K > 2^31). 
Lambda_s = Lambda if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype) gw_cost_val = torch.bmm( Lambda_s.unsqueeze(1), T_real.unsqueeze(2) ).sum().item() err = err_tensor.item() err_list.append(err) if verbose and (n_iter % verbose_every == 0 or i == max_iter - 1): print(f" iter {n_iter:>4}/{max_iter} | err: {err:.4e} | " f"gw_cost: {gw_cost_val:.4e} | reg: {current_reg:.4e}") # Convergence: plan change OR cost EMA plateau _cost_ema = gw_cost_val if _cost_ema is None else 0.8 * _cost_ema + 0.2 * gw_cost_val if i >= min_iter_before_converge: if err < tol: if verbose: print(f" converged at iteration {n_iter} (err={err:.4e})") break if _cost_ema < _best_cost_ema * 0.995: _best_cost_ema = _cost_ema _no_improve = 0 else: _no_improve += 1 if _no_improve >= _patience: if verbose: print(f" cost plateau at iteration {n_iter} " f"(no improve for {_patience} iters, gw_cost={gw_cost_val:.4e})") break del T_new if use_augmented: del T_aug if provider is not None: del D_left, D_tgt, Lambda_gw # Cast back to float64 for output precision if T_real.dtype != torch.float64: try: T_out = T_real.to(torch.float64) except torch.cuda.OutOfMemoryError: T_out = T_real # keep float32 if float64 copy would OOM else: T_out = T_real if differentiable: return T_out, err_list, n_iter, gw_cost_val return T_out.detach(), err_list, n_iter, gw_cost_val # ── Multiscale helper ──────────────────────────────────────────────────── def _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn, solver_kwargs, ): """Run coarse solve and return upsampled T_init, or None.""" if not multiscale or X_source is None or X_target is None: return None from torchgw._multiscale import fps_downsample, upsample_plan _n_coarse = n_coarse if n_coarse is not None else min(500, N // 4, K // 4) _n_coarse = max(_n_coarse, 10) if _n_coarse >= N or _n_coarse >= K: return None idx_src, assign_src = 
fps_downsample(X_source, _n_coarse) idx_tgt, assign_tgt = fps_downsample(X_target, _n_coarse) X_src_coarse = X_source[idx_src] X_tgt_coarse = X_target[idx_tgt] C_lin_coarse = None if C_linear_t is not None and fgw_alpha > 0: C_lin_coarse = C_linear_t[idx_src][:, idx_tgt] dist_src_coarse = None dist_tgt_coarse = None if dist_source is not None and dist_target is not None: dist_src_coarse = dist_source[idx_src][:, idx_src] dist_tgt_coarse = dist_target[idx_tgt][:, idx_tgt] coarse_kwargs = {**solver_kwargs} coarse_kwargs['multiscale'] = False # no recursion coarse_kwargs['M'] = min(solver_kwargs.get('M', 50), max(_n_coarse // 2, 10)) coarse_kwargs['k'] = min(solver_kwargs.get('k', 30), _n_coarse - 1) if solver_kwargs.get('s_shared') is not None: coarse_kwargs['s_shared'] = min(solver_kwargs['s_shared'], _n_coarse) T_coarse = solver_fn( X_src_coarse, X_tgt_coarse, distance_mode=distance_mode, dist_source=dist_src_coarse, dist_target=dist_tgt_coarse, n_landmarks=n_landmarks, fgw_alpha=fgw_alpha, C_linear=C_lin_coarse, device=device, **coarse_kwargs, ) T_init = upsample_plan(T_coarse, assign_src, assign_tgt, p_real, q_real) return T_init # ── Public API: standard solver ─────────────────────────────────────────
[docs] def sampled_gw( X_source: np.ndarray | torch.Tensor | None = None, X_target: np.ndarray | torch.Tensor | None = None, p: np.ndarray | torch.Tensor | None = None, q: np.ndarray | torch.Tensor | None = None, *, distance_mode: str = "dijkstra", dist_source: np.ndarray | torch.Tensor | None = None, dist_target: np.ndarray | torch.Tensor | None = None, n_landmarks: int = 50, fgw_alpha: float = 0.0, C_linear: np.ndarray | torch.Tensor | None = None, s_shared: int | None = None, M: int = 50, alpha: float = 0.9, max_iter: int = 500, tol: float = 1e-5, epsilon: float = 0.001, k: int = 30, min_iter_before_converge: int = 50, device: torch.device | None = None, verbose: bool = False, verbose_every: int = 20, log: bool = False, differentiable: bool = False, grad_mode: str = "implicit", semi_relaxed: bool = False, rho: float | None = None, # legacy single-side semi-relaxed alias rho_a: float = 1.0, rho_b: float = 1.0, multiscale: bool = False, n_coarse: int | None = None, lambda_ema_beta: float | None = None, mixed_precision: bool = False, T_init: np.ndarray | torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, dict]: """Sampled Gromov-Wasserstein alignment between two datasets. Uses standard log-domain Sinkhorn with slack augmentation for partial transport. Parameters ---------- X_source, X_target : ndarray or Tensor, optional Feature matrices. Required unless dist matrices or C_linear provided. p, q : ndarray or Tensor, optional Marginal distributions (uniform if None). distance_mode : str ``"dijkstra"`` (default), ``"precomputed"``, or ``"landmark"``. dist_source, dist_target : ndarray or Tensor, optional Precomputed (ns, ns) and (nt, nt) distance matrices. n_landmarks : int Landmark count for ``distance_mode="landmark"``. fgw_alpha : float FGW blending: 0 = pure GW, 1 = pure Wasserstein. C_linear : ndarray or Tensor, optional (ns, nt) feature cost for FGW. 
s_shared, M, alpha, max_iter, tol, epsilon, k : solver parameters min_iter_before_converge : int device : torch.device, optional verbose, verbose_every : progress printing log : bool Return (T, log_dict) if True. differentiable : bool Keep computation graph for backprop. grad_mode : str Gradient computation mode when ``differentiable=True``. ``"implicit"`` (default): exact via adjoint at Sinkhorn fixed point. Memory-efficient: O(NK). ``"unrolled"``: exact via unrolled PyTorch autograd. Memory: O(NK * sinkhorn_iters). semi_relaxed : bool Relax target marginal via KL penalty. rho_a, rho_b : float KL penalty weights for source and target marginals (semi_relaxed only). multiscale : bool Two-stage coarse-to-fine warm start. n_coarse : int, optional Coarse problem size (auto if None). lambda_ema_beta : float, optional EMA smoothing factor for the cost matrix. When set, maintains a running average: Lambda_ema = (1-beta)*Lambda_ema + beta*Lambda_sample. Reduces sampling variance at the cost of small bias that vanishes at convergence. Typical values: 0.3–0.7. None disables (default). mixed_precision : bool Run Sinkhorn iterations in float32 for speed, cast result back to float64. Safe because all critical ops are in log domain where values are O(log N). Marginals and transport plan stay in float64. 
Returns ------- T : Tensor (ns, nt) log_dict : dict (only if log=True) """ if rho is not None: if rho_a != 1.0 or rho_b != 1.0: raise ValueError("pass either rho (legacy) OR rho_a/rho_b, not both") rho_a, rho_b = 1.0, rho # legacy: damping only on v-side (target) (X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device) = _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ) # Marginals (float64) if p is not None: p_real = p.to(dtype=torch.float64, device=device) else: p_real = torch.ones(N, device=device, dtype=torch.float64) / N if q is not None: q_real = q.to(dtype=torch.float64, device=device) else: q_real = torch.ones(K, device=device, dtype=torch.float64) / K # User-supplied warm start takes precedence; else multiscale; else uniform. if T_init is not None: if isinstance(T_init, np.ndarray): T_init = torch.from_numpy(T_init) T_init = T_init.to(device=device, dtype=torch.float64) else: T_init = _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn=sampled_gw, solver_kwargs=dict( s_shared=s_shared, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, k=k, min_iter_before_converge=min_iter_before_converge, verbose=False, log=False, differentiable=differentiable, semi_relaxed=semi_relaxed, rho_a=rho_a, rho_b=rho_b, ), ) if T_init is None: T_init = torch.outer(p_real, q_real) # C_linear on device C_lin_device = C_linear_t.to(dtype=torch.float64, device=device) if C_linear_t is not None and fgw_alpha > 0 else None # Sinkhorn function if differentiable and fgw_alpha == 0.0: warnings.warn( "differentiable=True with fgw_alpha=0 (pure GW): gradients cannot " "flow because the GW cost matrix is built from precomputed graph " "distances that are not part of the computation graph. 
Set " "fgw_alpha > 0 with a differentiable C_linear to get useful gradients.", stacklevel=2, ) ctx = torch.no_grad() if not differentiable else torch.enable_grad() if differentiable: _gm = grad_mode def sinkhorn_fn(a, b, C, reg, **kw): # _sinkhorn_differentiable does not support warm-start kwargs for _stripped in ('_inplace_C', 'log_u_init', 'log_v_init', 'verbose'): kw.pop(_stripped, None) return _sinkhorn_differentiable(C, a, b, reg, grad_mode=_gm, **kw) else: sinkhorn_fn = _sinkhorn_torch with ctx: T_out, err_list, n_iter, gw_cost_val = _gw_loop( N=N, K=K, provider=provider, p_real=p_real, q_real=q_real, T_init=T_init, sinkhorn_fn=sinkhorn_fn, use_augmented=True, s_shared=s_shared, fgw_alpha=fgw_alpha, C_lin_device=C_lin_device, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, min_iter_before_converge=min_iter_before_converge, device=device, verbose=verbose, verbose_every=verbose_every, semi_relaxed=semi_relaxed, rho_a=rho_a, rho_b=rho_b, differentiable=differentiable, lambda_ema_beta=lambda_ema_beta, mixed_precision=mixed_precision, ) if log: return T_out, {"err_list": err_list, "n_iter": n_iter, "gw_cost": gw_cost_val} return T_out
# ── Public API: low-rank solver ─────────────────────────────────────────
[docs] def sampled_lowrank_gw( X_source: np.ndarray | torch.Tensor | None = None, X_target: np.ndarray | torch.Tensor | None = None, p: np.ndarray | torch.Tensor | None = None, q: np.ndarray | torch.Tensor | None = None, *, rank: int = 20, lr_max_iter: int = 5, lr_dykstra_max_iter: int = 50, distance_mode: str = "dijkstra", dist_source: np.ndarray | torch.Tensor | None = None, dist_target: np.ndarray | torch.Tensor | None = None, n_landmarks: int = 50, fgw_alpha: float = 0.0, C_linear: np.ndarray | torch.Tensor | None = None, s_shared: int | None = None, M: int = 50, alpha: float = 0.9, max_iter: int = 500, tol: float = 1e-5, epsilon: float = 0.001, k: int = 30, min_iter_before_converge: int = 50, device: torch.device | None = None, verbose: bool = False, verbose_every: int = 20, log: bool = False, semi_relaxed: bool = False, rho: float | None = None, # legacy single-side semi-relaxed alias rho_a: float = 1.0, rho_b: float = 1.0, multiscale: bool = False, n_coarse: int | None = None, lambda_ema_beta: float | None = None, mixed_precision: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, dict]: """Sampled Gromov-Wasserstein with low-rank Sinkhorn. Uses the low-rank Sinkhorn factorization (Scetbon, Cuturi & Peyre 2021) to reduce memory from O(NK) to O((N+K)*rank) per Sinkhorn step. This is a **memory optimization** for large-scale problems (N, K > 50k). At smaller scales, ``sampled_gw`` with standard Sinkhorn is faster. Parameters ---------- X_source, X_target : ndarray or Tensor, optional p, q : ndarray or Tensor, optional rank : int Nonneg. rank of the transport plan factorization. lr_max_iter : int Outer mirror descent iterations per Sinkhorn call. lr_dykstra_max_iter : int Inner Dykstra projection iterations per Sinkhorn call. 
distance_mode, dist_source, dist_target, n_landmarks : distance params fgw_alpha, C_linear : Fused GW params s_shared, M, alpha, max_iter, tol, epsilon, k : solver params min_iter_before_converge : int device : torch.device, optional verbose, verbose_every : progress printing log : bool semi_relaxed : bool rho_a, rho_b : float KL penalty weights for source and target marginals (semi_relaxed only). multiscale : bool n_coarse : int, optional lambda_ema_beta : float, optional EMA smoothing factor for the cost matrix (see ``sampled_gw``). mixed_precision : bool Run internal computations in float32 for speed (see ``sampled_gw``). Returns ------- T : Tensor (ns, nt) log_dict : dict (only if log=True) """ if rho is not None: if rho_a != 1.0 or rho_b != 1.0: raise ValueError("pass either rho (legacy) OR rho_a/rho_b, not both") rho_a, rho_b = 1.0, rho if semi_relaxed: raise ValueError("semi_relaxed is not supported for low-rank Sinkhorn") if rho_a != rho_b: raise NotImplementedError( "sampled_lowrank_gw does not yet support rho_a != rho_b " "(low-rank Dykstra requires symmetric KL). Use sampled_gw for " "fully-unbalanced (rho_a, rho_b) FGW." 
) from torchgw._lowrank import sinkhorn_lowrank (X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device) = _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ) # Marginals (float64) if p is not None: p_real = p.to(dtype=torch.float64, device=device) else: p_real = torch.ones(N, device=device, dtype=torch.float64) / N if q is not None: q_real = q.to(dtype=torch.float64, device=device) else: q_real = torch.ones(K, device=device, dtype=torch.float64) / K # Multiscale warm start T_init = _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn=sampled_lowrank_gw, solver_kwargs=dict( rank=rank, lr_max_iter=lr_max_iter, lr_dykstra_max_iter=lr_dykstra_max_iter, s_shared=s_shared, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, k=k, min_iter_before_converge=min_iter_before_converge, verbose=False, log=False, semi_relaxed=semi_relaxed, rho_a=rho_a, rho_b=rho_b, ), ) if T_init is None: T_init = torch.outer(p_real, q_real) # C_linear on device C_lin_device = C_linear_t.to(dtype=torch.float64, device=device) if C_linear_t is not None and fgw_alpha > 0 else None # Wrap sinkhorn_lowrank with fixed rank/iteration params def _lr_sinkhorn(a, b, C, reg, semi_relaxed=False, rho_a=1.0, rho_b=1.0, verbose=False, log_u_init=None, log_v_init=None): return sinkhorn_lowrank( a, b, C, rank=rank, reg=reg, max_iter=lr_max_iter, dykstra_max_iter=lr_dykstra_max_iter, ) with torch.no_grad(): T_out, err_list, n_iter, gw_cost_val = _gw_loop( N=N, K=K, provider=provider, p_real=p_real, q_real=q_real, T_init=T_init, sinkhorn_fn=_lr_sinkhorn, use_augmented=False, s_shared=s_shared, fgw_alpha=fgw_alpha, C_lin_device=C_lin_device, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, min_iter_before_converge=min_iter_before_converge, device=device, 
verbose=verbose, verbose_every=verbose_every, semi_relaxed=False, rho_a=rho_a, rho_b=rho_b, lambda_ema_beta=lambda_ema_beta, mixed_precision=mixed_precision, ) if log: return T_out, {"err_list": err_list, "n_iter": n_iter, "gw_cost": gw_cost_val} return T_out