Source code for torchgw._solver

import numpy as np
import torch

from torchgw._graph import build_knn_graph
from torchgw._sampling import sample_pairs_gpu
from torchgw._utils import get_device


# ── Sinkhorn core (shared by both no_grad and differentiable paths) ──────

def _sinkhorn_iterations(
    log_K: torch.Tensor,
    log_a: torch.Tensor,
    log_b: torch.Tensor,
    log_u: torch.Tensor,
    log_v: torch.Tensor,
    is_balanced: bool,
    tau: float,
    n_iter: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure Sinkhorn iterations without convergence check (compilable)."""
    for _ in range(n_iter):
        log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        log_v_raw = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
        if is_balanced:
            log_v = log_v_raw
        else:
            log_v = tau * log_v_raw + (1 - tau) * log_v
    return log_u, log_v


# torch.compile for kernel fusion (lazy init to avoid import-time compilation)
_sinkhorn_iterations_compiled = None


def _get_compiled_sinkhorn():
    global _sinkhorn_iterations_compiled
    if _sinkhorn_iterations_compiled is None:
        _sinkhorn_iterations_compiled = torch.compile(
            _sinkhorn_iterations, mode="reduce-overhead", dynamic=False,
        )
    return _sinkhorn_iterations_compiled


def _sinkhorn_loop(
    log_K: torch.Tensor,
    log_a: torch.Tensor,
    log_b: torch.Tensor,
    tau: float,
    max_iter: int,
    tol: float,
    check_every: int,
    a: torch.Tensor,
    verbose: bool = False,
    log_u_init: torch.Tensor | None = None,
    log_v_init: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run log-domain Sinkhorn iterations. Returns (log_u, log_v).

    Dispatch order:
      1. Triton fused kernels (if available + CUDA) — fastest
      2. torch.compile batched iterations (if available + CUDA) — fast
      3. Pure PyTorch fallback — always works
    """
    # Try Triton path first (single-pass fused logsumexp, no intermediate N×K)
    if log_K.is_cuda:
        try:
            from torchgw._triton_sinkhorn import triton_sinkhorn_loop
            return triton_sinkhorn_loop(log_K, log_a, log_b, tau, max_iter, tol, check_every, a,
                                        verbose, log_u_init=log_u_init, log_v_init=log_v_init)
        except (ImportError, RuntimeError):
            pass

        # Try torch.compile path
        try:
            iter_fn = _get_compiled_sinkhorn()
        except Exception:
            iter_fn = None

        if iter_fn is not None and not verbose and check_every > 1 and tol > 0:
            log_u = log_u_init if log_u_init is not None else torch.zeros_like(log_a)
            log_v = log_v_init if log_v_init is not None else torch.zeros_like(log_b)
            is_balanced = (tau == 1.0)
            done = 0
            while done < max_iter:
                batch = min(check_every, max_iter - done)
                log_u, log_v = iter_fn(
                    log_K, log_a, log_b, log_u, log_v,
                    is_balanced, tau, batch,
                )
                done += batch
                if tol > 0 and done % check_every == 0:
                    log_marginal = log_u + torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
                    marginal_err = torch.abs(torch.exp(log_marginal) - a).max().item()
                    if marginal_err < tol:
                        break
            return log_u, log_v

    # Pure PyTorch fallback
    return _sinkhorn_loop_pytorch(log_K, log_a, log_b, tau, max_iter, tol, check_every, a,
                                   verbose, log_u_init=log_u_init, log_v_init=log_v_init)


def _sinkhorn_loop_pytorch(
    log_K: torch.Tensor,
    log_a: torch.Tensor,
    log_b: torch.Tensor,
    tau: float,
    max_iter: int,
    tol: float,
    check_every: int,
    a: torch.Tensor,
    verbose: bool = False,
    log_u_init: torch.Tensor | None = None,
    log_v_init: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch Sinkhorn fallback (CPU or when Triton/compile unavailable)."""
    log_u = log_u_init if log_u_init is not None else torch.zeros_like(log_a)
    log_v = log_v_init if log_v_init is not None else torch.zeros_like(log_b)
    is_balanced = (tau == 1.0)

    for it in range(max_iter):
        log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        log_v_raw = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
        if is_balanced:
            log_v = log_v_raw
        else:
            log_v = tau * log_v_raw + (1 - tau) * log_v

        if tol > 0 and (it + 1) % check_every == 0:
            log_marginal = log_u + torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
            marginal_err = torch.abs(torch.exp(log_marginal) - a).max().item()
            if verbose:
                print(f"    sinkhorn {it+1:>4}/{max_iter} | marginal_err: {marginal_err:.4e}")
            if marginal_err < tol:
                if verbose:
                    print(f"    sinkhorn converged at {it+1} (err={marginal_err:.4e})")
                break

    return log_u, log_v


[docs] def _sinkhorn_torch( a: torch.Tensor, b: torch.Tensor, C: torch.Tensor, reg: float, max_iter: int = 100, tol: float = 5e-4, check_every: int = 10, semi_relaxed: bool = False, rho: float = 1.0, verbose: bool = False, log_u_init: torch.Tensor | None = None, log_v_init: torch.Tensor | None = None, ) -> torch.Tensor: """Log-domain Sinkhorn for numerical stability. Pure PyTorch. Operates in whatever dtype the inputs are given (float32 or float64). Dtype selection is handled by the caller (_gw_loop via sink_dtype). Supports warm-starting via log_u_init/log_v_init from a previous solve. """ log_K = -C / reg log_a = torch.log(a.clamp(min=1e-30)) log_b = torch.log(b.clamp(min=1e-30)) tau = rho / (rho + reg) if semi_relaxed else 1.0 log_u, log_v = _sinkhorn_loop(log_K, log_a, log_b, tau, max_iter, tol, check_every, a, verbose=verbose, log_u_init=log_u_init, log_v_init=log_v_init) # Fused T materialization (Triton on CUDA, PyTorch fallback) if log_K.is_cuda: try: from torchgw._triton_sinkhorn import triton_materialize_T T = triton_materialize_T(log_u, log_K, log_v) except (ImportError, RuntimeError): T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) else: T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) # Stash potentials for warm-starting the next call T._log_u = log_u.detach() # type: ignore[attr-defined] T._log_v = log_v.detach() # type: ignore[attr-defined] return T
class _SinkhornAutograd(torch.autograd.Function): """Memory-efficient differentiable Sinkhorn. Gradient via envelope theorem: dL/dC = -T * grad_T / reg. """ @staticmethod def forward(ctx, C, a, b, reg, max_iter, tol, check_every, semi_relaxed, rho): log_K = -C / reg log_a = torch.log(a.clamp(min=1e-30)) log_b = torch.log(b.clamp(min=1e-30)) tau = rho / (rho + reg) if semi_relaxed else 1.0 log_u, log_v = _sinkhorn_loop(log_K, log_a, log_b, tau, max_iter, tol, check_every, a) T = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0)) ctx.save_for_backward(T) ctx.reg = reg return T @staticmethod def backward(ctx, grad_T): (T,) = ctx.saved_tensors grad_C = -grad_T * T / ctx.reg return grad_C, None, None, None, None, None, None, None, None
[docs] def _sinkhorn_differentiable( a: torch.Tensor, b: torch.Tensor, C: torch.Tensor, reg: float, max_iter: int = 100, tol: float = 5e-4, check_every: int = 10, semi_relaxed: bool = False, rho: float = 1.0, verbose: bool = False, log_u_init: torch.Tensor | None = None, log_v_init: torch.Tensor | None = None, ) -> torch.Tensor: """Differentiable Sinkhorn using custom autograd (memory-efficient).""" if semi_relaxed: raise NotImplementedError( "differentiable=True is not supported with semi_relaxed=True: " "the envelope theorem gradient is only valid for balanced Sinkhorn" ) return _SinkhornAutograd.apply( C, a, b, reg, max_iter, tol, check_every, semi_relaxed, rho, )
# ── Input coercion ────────────────────────────────────────────────────── def _to_tensor(x): """Convert numpy array or tensor to torch.Tensor. Pass through None.""" if x is None: return None if isinstance(x, torch.Tensor): return x return torch.as_tensor(x) # ── Shared preprocessing ──────────────────────────────────────────────── def _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ): """Coerce inputs, infer N/K, validate, build distance provider. Returns (X_source, X_target, dist_source, dist_target, C_linear_t, N, K, provider, device). """ from torchgw._distances import DijkstraProvider, PrecomputedProvider, LandmarkProvider X_source = _to_tensor(X_source) X_target = _to_tensor(X_target) p = _to_tensor(p) q = _to_tensor(q) dist_source = _to_tensor(dist_source) dist_target = _to_tensor(dist_target) C_linear_t = _to_tensor(C_linear) # Infer N, K if dist_source is not None and dist_target is not None: N, K = dist_source.shape[0], dist_target.shape[0] elif X_source is not None and X_target is not None: N, K = X_source.shape[0], X_target.shape[0] elif C_linear_t is not None: N, K = C_linear_t.shape else: raise ValueError( "Cannot infer dataset sizes. Provide (X_source, X_target), " "(dist_source, dist_target), or C_linear." ) # Validate _VALID_MODES = {"precomputed", "dijkstra", "landmark"} if distance_mode not in _VALID_MODES: raise ValueError( f"distance_mode must be one of {_VALID_MODES}, got {distance_mode!r}" ) if fgw_alpha > 0 and C_linear_t is None: raise ValueError("fgw_alpha > 0 requires C_linear to be provided") if device is None: device = get_device() # Build distance provider if fgw_alpha >= 1.0: provider = None elif distance_mode == "precomputed": if dist_source is not None and dist_target is not None: provider = PrecomputedProvider(dist_source=dist_source, dist_target=dist_target) elif X_source is not None and X_target is not None: graphs = ( build_knn_graph(X_source.cpu().numpy(), k=k), build_knn_graph(X_target.cpu().numpy(), k=k), ) provider = PrecomputedProvider(graph_source=graphs[0], graph_target=graphs[1]) else: raise ValueError( "distance_mode='precomputed' requires (dist_source, dist_target) " "or (X_source, X_target)" ) elif distance_mode == "dijkstra": if X_source is None or X_target is None: raise ValueError("distance_mode='dijkstra' requires X_source and X_target") graph_source = build_knn_graph(X_source.cpu().numpy(), k=k) graph_target = build_knn_graph(X_target.cpu().numpy(), k=k) provider = DijkstraProvider(graph_source, graph_target) elif distance_mode == "landmark": if X_source is None or X_target is None: raise ValueError("distance_mode='landmark' requires X_source and X_target") graph_source = build_knn_graph(X_source.cpu().numpy(), k=k) graph_target = build_knn_graph(X_target.cpu().numpy(), k=k) provider = LandmarkProvider(graph_source, graph_target, n_landmarks=n_landmarks) return X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device def _gw_loop( *, N: int, K: int, provider, p_real: torch.Tensor, q_real: torch.Tensor, T_init: torch.Tensor, sinkhorn_fn, use_augmented: bool, s_shared: int | None, fgw_alpha: float, C_lin_device: torch.Tensor | None, M: int, alpha: float, max_iter: int, tol: float, epsilon: float, min_iter_before_converge: int, device: torch.device, verbose: bool, verbose_every: int, semi_relaxed: bool, rho: float, differentiable: bool = False, lambda_ema_beta: float | None = None, mixed_precision: bool = False, ) -> tuple[torch.Tensor, list, int, float]: """Shared main loop for sampled_gw and sampled_lowrank_gw. Parameters ---------- sinkhorn_fn : callable For standard: takes (p_aug, q_aug, Lambda_aug, reg, **kw) -> T_aug For low-rank: takes (p_real, q_real, Lambda, reg, **kw) -> T_new use_augmented : bool If True, build augmented cost/marginals and call sinkhorn_fn on them. If False, call sinkhorn_fn directly on (p_real, q_real, Lambda). Returns ------- T_out, err_list, n_iter, gw_cost_val """ if lambda_ema_beta is not None and not (0.0 <= lambda_ema_beta <= 1.0): raise ValueError(f"lambda_ema_beta must be in [0, 1], got {lambda_ema_beta}") if M < 1: raise ValueError(f"M must be >= 1, got {M}") # Sinkhorn internal dtype: float32 when mixed_precision, else float64 sink_dtype = torch.float32 if mixed_precision else torch.float64 T_real = T_init.to(sink_dtype) # Augmented marginals (only needed for standard Sinkhorn) if use_augmented: m_frac = s_shared / max(N, K) if s_shared is not None else min(N, K) / max(N, K) slack_p = max(q_real.sum().item() - m_frac, 1e-10) slack_q = max(p_real.sum().item() - m_frac, 1e-10) p_aug = torch.cat([p_real, torch.tensor([slack_p], device=device, dtype=torch.float64)]) q_aug = torch.cat([q_real, torch.tensor([slack_q], device=device, dtype=torch.float64)]) # Regularization decay (at most 10x reduction to avoid instability) initial_reg = epsilon if epsilon > 0 else 1e-4 final_reg = max(initial_reg / 10.0, min(5e-4, initial_reg)) decay = (final_reg / initial_reg) ** (1 / max(1, max_iter)) err_list = [] gw_cost_val = 0.0 n_iter = 0 Lambda_ema = None # EMA state for cost matrix smoothing _warm_log_u: torch.Tensor | None = None # Sinkhorn warm-start potentials _warm_log_v: torch.Tensor | None = None # Cost plateau detection via EMA + patience. # Critical because err = ||T - T_prev|| reflects sampling noise (not # optimization progress) and may never converge to tol. _cost_ema: float | None = None _best_cost_ema = float('inf') _no_improve = 0 _patience = max(min_iter_before_converge // 2, 20) # Pre-allocate augmented cost matrix and cast marginals (reused every iteration) if use_augmented: Lambda_aug = torch.zeros(N + 1, K + 1, device=device, dtype=sink_dtype) p_aug_sink = p_aug.to(sink_dtype) q_aug_sink = q_aug.to(sink_dtype) p_sink = p_real.to(sink_dtype) q_sink = q_real.to(sink_dtype) for i in range(max_iter): current_reg = initial_reg * (decay ** i) # In differentiable mode, detach T_prev to prevent computation graph # accumulation across iterations; only the final iteration's Sinkhorn # step will carry gradients through the momentum blend. T_prev = T_real.detach().clone() if differentiable else T_real.clone() # Sample anchor pairs (on GPU, only transfers 2*M ints back) j_left, l_target = sample_pairs_gpu(T_real.detach(), M) # Compute distances via provider if provider is not None: D_left, D_tgt = provider.get_distances(j_left, l_target, device) for D in [D_left, D_tgt]: inf_mask = torch.isinf(D) if torch.any(inf_mask): finite_vals = D[~inf_mask] fill = finite_vals.max() * 1.5 if finite_vals.numel() > 0 else 1.0 D[inf_mask] = fill mx = D.max() if mx > 0: D /= mx term_A = torch.mean(D_left ** 2, dim=1, keepdim=True) term_C = torch.mean(D_tgt ** 2, dim=1, keepdim=True).T term_B = -2 * (D_left @ D_tgt.T) / M Lambda_gw = term_A + term_B + term_C # Lambda EMA: smooth cost matrix across iterations # beta=0.0 is treated as disabled (same as None) if lambda_ema_beta is not None and lambda_ema_beta > 0: if Lambda_ema is None: Lambda_ema = Lambda_gw else: Lambda_ema = (1 - lambda_ema_beta) * Lambda_ema + lambda_ema_beta * Lambda_gw Lambda_gw = Lambda_ema.clone() else: Lambda_gw = None # FGW blending if fgw_alpha >= 1.0: Lambda = C_lin_device elif fgw_alpha > 0: Lambda = (1 - fgw_alpha) * Lambda_gw + fgw_alpha * C_lin_device else: Lambda = Lambda_gw # Sinkhorn step if use_augmented: Lambda_aug[:N, :K] = Lambda if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype) penalty = 100.0 * Lambda.max().clamp(min=1.0) # stays on GPU, no sync Lambda_aug[:-1, -1] = penalty Lambda_aug[-1, :-1] = penalty Lambda_aug[-1, -1] = 0.0 verbose_sink = verbose and (n_iter + 1) % verbose_every == 0 T_aug = sinkhorn_fn(p_aug_sink, q_aug_sink, Lambda_aug, current_reg, semi_relaxed=semi_relaxed, rho=rho, verbose=verbose_sink, log_u_init=_warm_log_u, log_v_init=_warm_log_v) T_new = T_aug[:-1, :-1] # Retrieve potentials for warm-starting next iteration _warm_log_u = getattr(T_aug, '_log_u', None) _warm_log_v = getattr(T_aug, '_log_v', None) else: verbose_sink = verbose and (n_iter + 1) % verbose_every == 0 Lambda_sink = Lambda if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype) T_new = sinkhorn_fn(p_sink, q_sink, Lambda_sink, current_reg, semi_relaxed=semi_relaxed, rho=rho, verbose=verbose_sink, log_u_init=_warm_log_u, log_v_init=_warm_log_v) _warm_log_u = getattr(T_new, '_log_u', None) _warm_log_v = getattr(T_new, '_log_v', None) # Momentum update T_real = (1 - alpha) * T_prev + alpha * T_new n_iter = i + 1 _check_interval = 5 # sync with CPU every N iterations # Compute metrics on GPU (no .item() sync) every iteration err_tensor = torch.linalg.norm(T_real - T_prev) # Only sync to CPU at check intervals (reduces CUDA sync overhead) if n_iter % _check_interval == 0 or i == max_iter - 1 or i >= min_iter_before_converge: Lambda_flat = Lambda.reshape(-1) if Lambda.dtype == sink_dtype else Lambda.to(sink_dtype).reshape(-1) gw_cost_val = torch.dot(Lambda_flat, T_real.reshape(-1)).item() err = err_tensor.item() err_list.append(err) if verbose and (n_iter % verbose_every == 0 or i == max_iter - 1): print(f" iter {n_iter:>4}/{max_iter} | err: {err:.4e} | " f"gw_cost: {gw_cost_val:.4e} | reg: {current_reg:.4e}") # Convergence: plan change OR cost EMA plateau _cost_ema = gw_cost_val if _cost_ema is None else 0.8 * _cost_ema + 0.2 * gw_cost_val if i >= min_iter_before_converge: if err < tol: if verbose: print(f" converged at iteration {n_iter} (err={err:.4e})") break if _cost_ema < _best_cost_ema * 0.995: _best_cost_ema = _cost_ema _no_improve = 0 else: _no_improve += 1 if _no_improve >= _patience: if verbose: print(f" cost plateau at iteration {n_iter} " f"(no improve for {_patience} iters, gw_cost={gw_cost_val:.4e})") break del T_new if use_augmented: del T_aug if provider is not None: del D_left, D_tgt, Lambda_gw # Cast back to float64 for output precision T_out = T_real if T_real.dtype == torch.float64 else T_real.to(torch.float64) if differentiable: return T_out, err_list, n_iter, gw_cost_val return T_out.detach(), err_list, n_iter, gw_cost_val # ── Multiscale helper ──────────────────────────────────────────────────── def _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn, solver_kwargs, ): """Run coarse solve and return upsampled T_init, or None.""" if not multiscale or X_source is None or X_target is None: return None from torchgw._multiscale import fps_downsample, upsample_plan _n_coarse = n_coarse if n_coarse is not None else min(500, N // 4, K // 4) _n_coarse = max(_n_coarse, 10) if _n_coarse >= N or _n_coarse >= K: return None idx_src, assign_src = fps_downsample(X_source, _n_coarse) idx_tgt, assign_tgt = fps_downsample(X_target, _n_coarse) X_src_coarse = X_source[idx_src] X_tgt_coarse = X_target[idx_tgt] C_lin_coarse = None if C_linear_t is not None and fgw_alpha > 0: C_lin_coarse = C_linear_t[idx_src][:, idx_tgt] dist_src_coarse = None dist_tgt_coarse = None if dist_source is not None and dist_target is not None: dist_src_coarse = dist_source[idx_src][:, idx_src] dist_tgt_coarse = dist_target[idx_tgt][:, idx_tgt] coarse_kwargs = {**solver_kwargs} coarse_kwargs['multiscale'] = False # no recursion coarse_kwargs['M'] = min(solver_kwargs.get('M', 50), max(_n_coarse // 2, 10)) coarse_kwargs['k'] = min(solver_kwargs.get('k', 30), _n_coarse - 1) if solver_kwargs.get('s_shared') is not None: coarse_kwargs['s_shared'] = min(solver_kwargs['s_shared'], _n_coarse) T_coarse = solver_fn( X_src_coarse, X_tgt_coarse, distance_mode=distance_mode, dist_source=dist_src_coarse, dist_target=dist_tgt_coarse, n_landmarks=n_landmarks, fgw_alpha=fgw_alpha, C_linear=C_lin_coarse, device=device, **coarse_kwargs, ) T_init = upsample_plan(T_coarse, assign_src, assign_tgt, p_real, q_real) return T_init # ── Public API: standard solver ─────────────────────────────────────────
[docs] def sampled_gw( X_source: np.ndarray | torch.Tensor | None = None, X_target: np.ndarray | torch.Tensor | None = None, p: np.ndarray | torch.Tensor | None = None, q: np.ndarray | torch.Tensor | None = None, *, distance_mode: str = "dijkstra", dist_source: np.ndarray | torch.Tensor | None = None, dist_target: np.ndarray | torch.Tensor | None = None, n_landmarks: int = 50, fgw_alpha: float = 0.0, C_linear: np.ndarray | torch.Tensor | None = None, s_shared: int | None = None, M: int = 50, alpha: float = 0.9, max_iter: int = 500, tol: float = 1e-5, epsilon: float = 0.001, k: int = 30, min_iter_before_converge: int = 50, device: torch.device | None = None, verbose: bool = False, verbose_every: int = 20, log: bool = False, differentiable: bool = False, semi_relaxed: bool = False, rho: float = 1.0, multiscale: bool = False, n_coarse: int | None = None, lambda_ema_beta: float | None = None, mixed_precision: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, dict]: """Sampled Gromov-Wasserstein alignment between two datasets. Uses standard log-domain Sinkhorn with slack augmentation for partial transport. Parameters ---------- X_source, X_target : ndarray or Tensor, optional Feature matrices. Required unless dist matrices or C_linear provided. p, q : ndarray or Tensor, optional Marginal distributions (uniform if None). distance_mode : str ``"dijkstra"`` (default), ``"precomputed"``, or ``"landmark"``. dist_source, dist_target : ndarray or Tensor, optional Precomputed (ns, ns) and (nt, nt) distance matrices. n_landmarks : int Landmark count for ``distance_mode="landmark"``. fgw_alpha : float FGW blending: 0 = pure GW, 1 = pure Wasserstein. C_linear : ndarray or Tensor, optional (ns, nt) feature cost for FGW. s_shared, M, alpha, max_iter, tol, epsilon, k : solver parameters min_iter_before_converge : int device : torch.device, optional verbose, verbose_every : progress printing log : bool Return (T, log_dict) if True. differentiable : bool Keep computation graph for backprop. semi_relaxed : bool Relax target marginal via KL penalty. rho : float KL penalty weight (semi_relaxed only). multiscale : bool Two-stage coarse-to-fine warm start. n_coarse : int, optional Coarse problem size (auto if None). lambda_ema_beta : float, optional EMA smoothing factor for the cost matrix. When set, maintains a running average: Lambda_ema = (1-beta)*Lambda_ema + beta*Lambda_sample. Reduces sampling variance at the cost of small bias that vanishes at convergence. Typical values: 0.3–0.7. None disables (default). mixed_precision : bool Run Sinkhorn iterations in float32 for speed, cast result back to float64. Safe because all critical ops are in log domain where values are O(log N). Marginals and transport plan stay in float64. Returns ------- T : Tensor (ns, nt) log_dict : dict (only if log=True) """ (X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device) = _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ) # Marginals (float64) if p is not None: p_real = p.to(dtype=torch.float64, device=device) else: p_real = torch.ones(N, device=device, dtype=torch.float64) / N if q is not None: q_real = q.to(dtype=torch.float64, device=device) else: q_real = torch.ones(K, device=device, dtype=torch.float64) / K # Multiscale warm start T_init = _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn=sampled_gw, solver_kwargs=dict( s_shared=s_shared, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, k=k, min_iter_before_converge=min_iter_before_converge, verbose=False, log=False, differentiable=differentiable, semi_relaxed=semi_relaxed, rho=rho, ), ) if T_init is None: T_init = torch.outer(p_real, q_real) # C_linear on device C_lin_device = C_linear_t.to(dtype=torch.float64, device=device) if C_linear_t is not None and fgw_alpha > 0 else None # Sinkhorn function ctx = torch.no_grad() if not differentiable else torch.enable_grad() sinkhorn_fn = _sinkhorn_differentiable if differentiable else _sinkhorn_torch with ctx: T_out, err_list, n_iter, gw_cost_val = _gw_loop( N=N, K=K, provider=provider, p_real=p_real, q_real=q_real, T_init=T_init, sinkhorn_fn=sinkhorn_fn, use_augmented=True, s_shared=s_shared, fgw_alpha=fgw_alpha, C_lin_device=C_lin_device, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, min_iter_before_converge=min_iter_before_converge, device=device, verbose=verbose, verbose_every=verbose_every, semi_relaxed=semi_relaxed, rho=rho, differentiable=differentiable, lambda_ema_beta=lambda_ema_beta, mixed_precision=mixed_precision, ) if log: return T_out, {"err_list": err_list, "n_iter": n_iter, "gw_cost": gw_cost_val} return T_out
# ── Public API: low-rank solver ─────────────────────────────────────────
[docs] def sampled_lowrank_gw( X_source: np.ndarray | torch.Tensor | None = None, X_target: np.ndarray | torch.Tensor | None = None, p: np.ndarray | torch.Tensor | None = None, q: np.ndarray | torch.Tensor | None = None, *, rank: int = 20, lr_max_iter: int = 5, lr_dykstra_max_iter: int = 50, distance_mode: str = "dijkstra", dist_source: np.ndarray | torch.Tensor | None = None, dist_target: np.ndarray | torch.Tensor | None = None, n_landmarks: int = 50, fgw_alpha: float = 0.0, C_linear: np.ndarray | torch.Tensor | None = None, s_shared: int | None = None, M: int = 50, alpha: float = 0.9, max_iter: int = 500, tol: float = 1e-5, epsilon: float = 0.001, k: int = 30, min_iter_before_converge: int = 50, device: torch.device | None = None, verbose: bool = False, verbose_every: int = 20, log: bool = False, semi_relaxed: bool = False, rho: float = 1.0, multiscale: bool = False, n_coarse: int | None = None, lambda_ema_beta: float | None = None, mixed_precision: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, dict]: """Sampled Gromov-Wasserstein with low-rank Sinkhorn. Uses the low-rank Sinkhorn factorization (Scetbon, Cuturi & Peyre 2021) to reduce memory from O(NK) to O((N+K)*rank) per Sinkhorn step. This is a **memory optimization** for large-scale problems (N, K > 50k). At smaller scales, ``sampled_gw`` with standard Sinkhorn is faster. Parameters ---------- X_source, X_target : ndarray or Tensor, optional p, q : ndarray or Tensor, optional rank : int Nonneg. rank of the transport plan factorization. lr_max_iter : int Outer mirror descent iterations per Sinkhorn call. lr_dykstra_max_iter : int Inner Dykstra projection iterations per Sinkhorn call. distance_mode, dist_source, dist_target, n_landmarks : distance params fgw_alpha, C_linear : Fused GW params s_shared, M, alpha, max_iter, tol, epsilon, k : solver params min_iter_before_converge : int device : torch.device, optional verbose, verbose_every : progress printing log : bool semi_relaxed : bool rho : float multiscale : bool n_coarse : int, optional lambda_ema_beta : float, optional EMA smoothing factor for the cost matrix (see ``sampled_gw``). mixed_precision : bool Run internal computations in float32 for speed (see ``sampled_gw``). Returns ------- T : Tensor (ns, nt) log_dict : dict (only if log=True) """ if semi_relaxed: raise ValueError("semi_relaxed is not supported for low-rank Sinkhorn") from torchgw._lowrank import sinkhorn_lowrank (X_source, X_target, p, q, dist_source, dist_target, C_linear_t, N, K, provider, device) = _prepare_inputs( X_source, X_target, p, q, dist_source, dist_target, C_linear, distance_mode, fgw_alpha, k, n_landmarks, device, ) # Marginals (float64) if p is not None: p_real = p.to(dtype=torch.float64, device=device) else: p_real = torch.ones(N, device=device, dtype=torch.float64) / N if q is not None: q_real = q.to(dtype=torch.float64, device=device) else: q_real = torch.ones(K, device=device, dtype=torch.float64) / K # Multiscale warm start T_init = _maybe_multiscale( multiscale, n_coarse, X_source, X_target, N, K, dist_source, dist_target, C_linear_t, fgw_alpha, distance_mode, n_landmarks, device, p_real, q_real, solver_fn=sampled_lowrank_gw, solver_kwargs=dict( rank=rank, lr_max_iter=lr_max_iter, lr_dykstra_max_iter=lr_dykstra_max_iter, s_shared=s_shared, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, k=k, min_iter_before_converge=min_iter_before_converge, verbose=False, log=False, semi_relaxed=semi_relaxed, rho=rho, ), ) if T_init is None: T_init = torch.outer(p_real, q_real) # C_linear on device C_lin_device = C_linear_t.to(dtype=torch.float64, device=device) if C_linear_t is not None and fgw_alpha > 0 else None # Wrap sinkhorn_lowrank with fixed rank/iteration params def _lr_sinkhorn(a, b, C, reg, semi_relaxed=False, rho=1.0, verbose=False, log_u_init=None, log_v_init=None): return sinkhorn_lowrank( a, b, C, rank=rank, reg=reg, max_iter=lr_max_iter, dykstra_max_iter=lr_dykstra_max_iter, ) with torch.no_grad(): T_out, err_list, n_iter, gw_cost_val = _gw_loop( N=N, K=K, provider=provider, p_real=p_real, q_real=q_real, T_init=T_init, sinkhorn_fn=_lr_sinkhorn, use_augmented=False, s_shared=s_shared, fgw_alpha=fgw_alpha, C_lin_device=C_lin_device, M=M, alpha=alpha, max_iter=max_iter, tol=tol, epsilon=epsilon, min_iter_before_converge=min_iter_before_converge, device=device, verbose=verbose, verbose_every=verbose_every, semi_relaxed=False, rho=rho, lambda_ema_beta=lambda_ema_beta, mixed_precision=mixed_precision, ) if log: return T_out, {"err_list": err_list, "n_iter": n_iter, "gw_cost": gw_cost_val} return T_out