API Reference
Solver
- torchgw.sampled_gw(X_source=None, X_target=None, p=None, q=None, *, distance_mode='dijkstra', dist_source=None, dist_target=None, n_landmarks=50, fgw_alpha=0.0, C_linear=None, s_shared=None, M=50, alpha=0.9, max_iter=500, tol=1e-05, epsilon=0.001, k=30, min_iter_before_converge=50, device=None, verbose=False, verbose_every=20, log=False, differentiable=False, semi_relaxed=False, rho=1.0, multiscale=False, n_coarse=None, lambda_ema_beta=None, mixed_precision=False)[source]
Sampled Gromov-Wasserstein alignment between two datasets.
Uses standard log-domain Sinkhorn with slack augmentation for partial transport.
- Parameters:
X_source (ndarray or Tensor, optional) – Source feature matrix. Required unless distance matrices or C_linear are provided.
X_target (ndarray or Tensor, optional) – Target feature matrix. Required unless distance matrices or C_linear are provided.
p (ndarray or Tensor, optional) – Source marginal distribution (uniform if None).
q (ndarray or Tensor, optional) – Target marginal distribution (uniform if None).
distance_mode (str) – "dijkstra" (default), "precomputed", or "landmark".
dist_source (ndarray or Tensor, optional) – Precomputed (ns, ns) source distance matrix.
dist_target (ndarray or Tensor, optional) – Precomputed (nt, nt) target distance matrix.
n_landmarks (int) – Landmark count for distance_mode="landmark".
fgw_alpha (float) – FGW blending: 0 = pure GW, 1 = pure Wasserstein.
C_linear (ndarray or Tensor, optional) – (ns, nt) feature cost for FGW.
s_shared – Solver parameter.
M (int) – Number of (row, col) pairs sampled from the plan per iteration (see sample_pairs_from_plan).
alpha (float) – Solver parameter.
max_iter (int) – Maximum number of outer iterations.
tol (float) – Convergence tolerance.
epsilon (float) – Entropic regularization strength for the Sinkhorn iterations.
k (int) – Number of nearest neighbors for k-NN graph construction (see build_knn_graph).
min_iter_before_converge (int) – Minimum iterations before the convergence check applies.
device (torch.device, optional) – Compute device.
verbose (bool) – Print progress.
verbose_every (int) – Print progress every verbose_every iterations.
log (bool) – Return (T, log_dict) if True.
differentiable (bool) – Keep computation graph for backprop.
semi_relaxed (bool) – Relax target marginal via KL penalty.
rho (float) – KL penalty weight (semi_relaxed only).
multiscale (bool) – Two-stage coarse-to-fine warm start.
n_coarse (int, optional) – Coarse problem size (auto if None).
lambda_ema_beta (float, optional) – EMA smoothing factor for the cost matrix. When set, maintains a running average: Lambda_ema = (1-beta)*Lambda_ema + beta*Lambda_sample. Reduces sampling variance at the cost of small bias that vanishes at convergence. Typical values: 0.3–0.7. None disables (default).
mixed_precision (bool) – Run Sinkhorn iterations in float32 for speed, cast result back to float64. Safe because all critical ops are in log domain where values are O(log N). Marginals and transport plan stay in float64.
- Returns:
T (Tensor of shape (ns, nt)) – Transport plan.
log_dict (dict) – Only returned if log=True.
- Return type:
Tensor, or (Tensor, dict) if log=True
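A minimal usage sketch, assuming only the signature above (the data, shapes, and hyperparameter choices are illustrative, not prescriptive):

    import numpy as np
    import torchgw

    rng = np.random.default_rng(0)
    X_src = rng.normal(size=(500, 10))   # source features (ns, d)
    X_tgt = rng.normal(size=(600, 10))   # target features (nt, d)

    # The default "dijkstra" distance_mode builds distances internally;
    # lambda_ema_beta smooths the sampled cost matrix as described above.
    T, info = torchgw.sampled_gw(
        X_src, X_tgt,
        epsilon=1e-3,
        lambda_ema_beta=0.5,
        log=True,
    )
    print(T.shape)  # torch.Size([500, 600])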
- torchgw.sampled_lowrank_gw(X_source=None, X_target=None, p=None, q=None, *, rank=20, lr_max_iter=5, lr_dykstra_max_iter=50, distance_mode='dijkstra', dist_source=None, dist_target=None, n_landmarks=50, fgw_alpha=0.0, C_linear=None, s_shared=None, M=50, alpha=0.9, max_iter=500, tol=1e-05, epsilon=0.001, k=30, min_iter_before_converge=50, device=None, verbose=False, verbose_every=20, log=False, semi_relaxed=False, rho=1.0, multiscale=False, n_coarse=None, lambda_ema_beta=None, mixed_precision=False)[source]
Sampled Gromov-Wasserstein with low-rank Sinkhorn.
Uses the low-rank Sinkhorn factorization (Scetbon, Cuturi & Peyré 2021) to reduce memory from O(NK) to O((N+K)*rank) per Sinkhorn step.
This is a memory optimization for large-scale problems (N, K > 50k). At smaller scales, sampled_gw with standard Sinkhorn is faster.
- Parameters:
X_source (ndarray or Tensor, optional)
X_target (ndarray or Tensor, optional)
p (ndarray or Tensor, optional)
q (ndarray or Tensor, optional)
rank (int) – Nonnegative rank of the transport plan factorization.
lr_max_iter (int) – Outer mirror descent iterations per Sinkhorn call.
lr_dykstra_max_iter (int) – Inner Dykstra projection iterations per Sinkhorn call.
distance_mode, dist_source, dist_target, n_landmarks – Distance parameters; see sampled_gw.
fgw_alpha, C_linear – Fused GW parameters; see sampled_gw.
s_shared, M, alpha, max_iter, tol, epsilon, k, min_iter_before_converge – Solver parameters; see sampled_gw.
device (torch.device, optional) – Compute device.
verbose, verbose_every – Progress printing; see sampled_gw.
log (bool) – Return (T, log_dict) if True.
semi_relaxed (bool) – Relax target marginal via KL penalty.
rho (float) – KL penalty weight (semi_relaxed only).
multiscale (bool) – Two-stage coarse-to-fine warm start.
n_coarse (int, optional) – Coarse problem size (auto if None).
lambda_ema_beta (float, optional) – EMA smoothing factor for the cost matrix (see sampled_gw).
mixed_precision (bool) – Run internal computations in float32 for speed (see sampled_gw).
- Returns:
T (Tensor of shape (ns, nt)) – Transport plan.
log_dict (dict) – Only returned if log=True.
- Return type:
Tensor, or (Tensor, dict) if log=True
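A usage sketch under the same assumptions as the sampled_gw example; rank trades plan expressiveness for memory:

    import numpy as np
    import torchgw

    rng = np.random.default_rng(0)
    X_src = rng.normal(size=(2000, 10))
    X_tgt = rng.normal(size=(2500, 10))

    # Memory per Sinkhorn step scales as O((N + K) * rank) instead of
    # O(NK), which is what makes this variant viable for N, K > 50k.
    T = torchgw.sampled_lowrank_gw(X_src, X_tgt, rank=20)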
Graph Construction
- torchgw.build_knn_graph(X, k=30)[source]
Build a k-NN graph, stitching disconnected components if needed.
- Parameters:
X (ndarray of shape (n_samples, n_features))
k (int) – Number of nearest neighbors.
- Returns:
graph – Symmetric sparse distance graph, guaranteed connected.
- Return type:
csr_matrix of shape (n_samples, n_samples)
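A short sketch of typical use (array sizes are illustrative):

    import numpy as np
    import torchgw

    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 16))

    # Symmetric scipy.sparse CSR matrix; disconnected k-NN components
    # are stitched, so the result is guaranteed connected.
    graph = torchgw.build_knn_graph(X, k=30)
    print(graph.shape)  # (1000, 1000)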
Joint Embedding
- torchgw.joint_embedding(anchor_name, data_by_name, graphs_by_name, transport_plans, lambda_reg=1.0, out_dim=30)[source]
Compute joint manifold embedding using transport plans.
- Parameters:
anchor_name (str) – Name of the anchor/reference dataset.
data_by_name (dict mapping name -> (n_samples, n_features) array)
graphs_by_name (dict mapping name -> kNN graph (csr_matrix))
transport_plans (dict mapping (query_name, anchor_name) -> T array)
lambda_reg (float) – Regularization weight for the graph Laplacian term.
out_dim (int) – Dimensionality of output embedding.
- Returns:
embeddings
- Return type:
dict mapping name -> (n_samples, out_dim) array
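A sketch tying the pieces together, assuming two datasets named "ref" (the anchor) and "query"; all names, sizes, and the direct reuse of the plan from sampled_gw are illustrative:

    import numpy as np
    import torchgw

    rng = np.random.default_rng(0)
    data = {
        "ref": rng.normal(size=(800, 10)),
        "query": rng.normal(size=(600, 10)),
    }
    graphs = {name: torchgw.build_knn_graph(X, k=30) for name, X in data.items()}

    # Transport plan mapping "query" onto the anchor "ref".
    T = torchgw.sampled_gw(data["query"], data["ref"])

    embeddings = torchgw.joint_embedding(
        anchor_name="ref",
        data_by_name=data,
        graphs_by_name=graphs,
        transport_plans={("query", "ref"): T},  # convert T to ndarray if required
        lambda_reg=1.0,
        out_dim=30,
    )
    # embeddings["ref"] and embeddings["query"] now share a 30-dim space.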
Internal Modules
Sinkhorn
- torchgw._solver._sinkhorn_torch(a, b, C, reg, max_iter=100, tol=0.0005, check_every=10, semi_relaxed=False, rho=1.0, verbose=False, log_u_init=None, log_v_init=None)[source]
Log-domain Sinkhorn for numerical stability. Pure PyTorch.
Operates in whatever dtype the inputs are given (float32 or float64). Dtype selection is handled by the caller (_gw_loop via sink_dtype). Supports warm-starting via log_u_init/log_v_init from a previous solve.
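For orientation, a minimal sketch of a balanced log-domain Sinkhorn update of the kind this routine performs (an illustration only, not the library's internal code; the semi_relaxed KL variant and warm-start plumbing are omitted):

    import torch

    def log_sinkhorn_sketch(a, b, C, reg, max_iter=100, tol=5e-4, check_every=10):
        """Illustrative balanced log-domain Sinkhorn (not torchgw's internals)."""
        log_a, log_b = torch.log(a), torch.log(b)
        log_K = -C / reg                       # Gibbs kernel in log space
        log_u = torch.zeros_like(log_a)
        log_v = torch.zeros_like(log_b)
        for it in range(max_iter):
            # logsumexp replaces the matrix-vector products of standard
            # Sinkhorn, so the scalings never overflow or underflow.
            log_u = log_a - torch.logsumexp(log_K + log_v[None, :], dim=1)
            log_v = log_b - torch.logsumexp(log_K + log_u[:, None], dim=0)
            if (it + 1) % check_every == 0:
                T = torch.exp(log_u[:, None] + log_K + log_v[None, :])
                if torch.abs(T.sum(dim=1) - a).max() < tol:
                    break
        return torch.exp(log_u[:, None] + log_K + log_v[None, :])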
Sampling
- torchgw._sampling.sample_pairs_from_plan(T, M, rng=None)[source]
Sample M (row, col) pairs from transport plan T, weighted by mass.
Uses the Gumbel-max trick for vectorized categorical sampling, avoiding a Python for-loop over M pairs.
- Parameters:
T (ndarray of shape (N, K), non-negative)
M (int) – Number of pairs to sample.
rng (numpy Generator, optional) – Random number generator for reproducibility. Uses a new default generator if None.
- Returns:
rows (ndarray of shape (M,))
cols (ndarray of shape (M,))
- Return type:
tuple of (ndarray, ndarray)
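A sketch of the Gumbel-max trick the sampler relies on, assuming sampling with replacement from the flattened plan (illustrative; the real implementation may organize this differently):

    import numpy as np

    def gumbel_max_pairs(T, M, rng=None):
        """Draw M (row, col) indices with probability proportional to T[i, j]."""
        rng = np.random.default_rng() if rng is None else rng
        logits = np.log(T.ravel() + 1e-300)      # log-weights of all N*K cells
        # One Gumbel noise vector per sample: the argmax of (logits + noise)
        # is an exact categorical draw, vectorized over all M samples.
        # (Materializes an (M, N*K) array, fine for illustration only.)
        noise = rng.gumbel(size=(M, logits.size))
        flat = np.argmax(logits[None, :] + noise, axis=1)
        return np.unravel_index(flat, T.shape)   # (rows, cols), each shape (M,)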