Quick Start

Installation

git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e .

Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib. Source code: github.com/chansigit/torchgw.

Basic Usage

import numpy as np
from torchgw import sampled_gw, build_knn_graph

# Two point clouds (dimensions may differ)
X = np.random.randn(500, 3).astype(np.float32)
Y = np.random.randn(600, 5).astype(np.float32)

# Compute transport plan
T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)

# T[i,j] is the coupling weight between X[i] and Y[j]
print(T.shape)  # (500, 600)
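A common way to use a coupling like T (illustrated here with plain numpy, not a torchgw API) is barycentric projection: map each source point to the T-weighted average of target points, landing X in Y's ambient space. The toy uniform coupling below is a stand-in for a real transport plan.

```python
import numpy as np

rng = np.random.default_rng(0)
Y = rng.standard_normal((3, 5)).astype(np.float32)

# Toy stand-in for a transport plan: a uniform coupling between
# 4 source and 3 target points (each row sums to 1/4).
T = np.full((4, 3), 1.0 / (4 * 3), dtype=np.float32)

# Barycentric projection: row-normalize T, then average target points.
# Each source point lands in Y's 5-dimensional space.
X_mapped = (T / T.sum(axis=1, keepdims=True)) @ Y

print(X_mapped.shape)  # (4, 5)
```

With the uniform toy coupling every source point maps to the centroid of Y; a real plan produces point-specific matches.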

With Precomputed Graphs

Building the kNN graphs once and reusing them avoids recomputing them on every call:

g_src = build_knn_graph(X, k=10)
g_tgt = build_knn_graph(Y, k=10)

T = sampled_gw(X, Y, graph_source=g_src, graph_target=g_tgt,
               epsilon=0.005, M=80, max_iter=200)
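The internals of build_knn_graph are not shown here; as a rough sketch of what such a graph looks like, scikit-learn (already a dependency) can build a sparse kNN connectivity matrix directly:

```python
import numpy as np
from sklearn.neighbors import kneighbors_graph

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3)).astype(np.float32)

# Sparse CSR adjacency: row i marks the 10 nearest neighbors of X[i]
# (the point itself is excluded by default) with weight 1.
g = kneighbors_graph(X, n_neighbors=10, mode="connectivity")

print(g.shape)  # (100, 100)
print(g.nnz)    # 1000, i.e. 10 neighbors per point
```

Whether torchgw's own graph format matches this CSR layout is an assumption; consult build_knn_graph's docstring for the exact return type.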

Semi-Relaxed Mode

When source and target have different compositions (e.g., a cell type present in the source but absent from the target), balanced GW forces mass onto incorrect matches. Semi-relaxed GW keeps the source marginal fixed but lets the target marginal adapt:

# Balanced (default): T @ 1 = p,  T.T @ 1 = q  (both enforced)
T = sampled_gw(X, Y, epsilon=0.005)

# Semi-relaxed: T @ 1 = p (enforced),  T.T @ 1 ≈ q (soft KL penalty)
T = sampled_gw(X, Y, epsilon=0.005, semi_relaxed=True, rho=1.0)

# rho controls how strictly q is enforced:
#   rho → ∞  : recovers balanced GW
#   rho → 0  : target marginal is completely free
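To make the soft constraint concrete, here is a numeric sketch of a generalized-KL penalty on the column marginal, scaled by rho. The exact penalty form torchgw uses is an assumption; this only illustrates the behavior described in the comments above.

```python
import numpy as np

def kl_div(a, b):
    # Generalized KL(a || b) for nonnegative vectors.
    return float(np.sum(a * np.log(a / b) - a + b))

n, m = 4, 3
q = np.full(m, 1.0 / m)

# A coupling whose column marginal matches q exactly: zero penalty.
T_exact = np.full((n, m), 1.0 / (n * m))
penalty_exact = kl_div(T_exact.sum(axis=0), q)

# A coupling that piles all mass onto the first target column:
# positive penalty, scaled by rho in the semi-relaxed objective.
T_skew = np.zeros((n, m))
T_skew[:, 0] = 1.0 / n
rho = 1.0
penalty_skew = rho * kl_div(T_skew.sum(axis=0) + 1e-12, q)

print(penalty_exact)      # ~0
print(penalty_skew > 0)   # True
```

Larger rho makes deviations from q costlier, matching the rho → ∞ and rho → 0 limits noted above.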

Convergence Logging

T, info = sampled_gw(X, Y, epsilon=0.005, max_iter=200, log=True)
print(info["n_iter"])    # actual iterations run
print(info["err_list"])  # per-iteration convergence errors
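To see where a per-iteration error list like info["err_list"] typically comes from, here is a minimal entropic Sinkhorn loop that records the marginal violation at each step. This is a generic sketch, not torchgw's solver, and the error metric (max deviation of the column marginal) is an assumption.

```python
import numpy as np

def sinkhorn_with_log(C, p, q, epsilon=0.1, max_iter=200, tol=1e-9):
    """Minimal entropic Sinkhorn that records per-iteration errors."""
    K = np.exp(-C / epsilon)
    u = np.ones_like(p)
    err_list = []
    for it in range(max_iter):
        v = q / (K.T @ u)
        u = p / (K @ v)
        # After the u-update the row marginal is exact, so measure
        # how far the column marginal still is from q.
        T = u[:, None] * K * v[None, :]
        err = np.abs(T.sum(axis=0) - q).max()
        err_list.append(err)
        if err < tol:
            break
    return T, {"n_iter": it + 1, "err_list": err_list}

rng = np.random.default_rng(0)
C = rng.random((5, 6))
p = np.full(5, 1 / 5)
q = np.full(6, 1 / 6)
T, info = sinkhorn_with_log(C, p, q)
print(info["n_iter"], info["err_list"][-1])
```

A monotonically shrinking err_list is the usual sign of healthy convergence; a plateau suggests epsilon is too small for the cost scale.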

Differentiable Mode

For end-to-end training, keep the computation graph:

T = sampled_gw(X, Y, epsilon=0.005, differentiable=True)
# T is differentiable w.r.t. the cost matrix via the envelope theorem

This mode uses a custom torch.autograd.Function that saves only the transport plan (not the full Sinkhorn iteration history), so the memory overhead is minimal.
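The mechanism can be sketched as follows (a toy stand-in solver, not torchgw's actual implementation): the forward pass computes a plan with gradients disabled and saves only that plan; by the envelope theorem, the gradient of the transport cost ⟨C, T*⟩ with respect to C is T* itself, so backward just replays the saved tensor.

```python
import torch

class EnvelopeOTCost(torch.autograd.Function):
    """Toy envelope-theorem cost: forward solves for a plan without
    tracking gradients; backward returns the saved plan as dcost/dC."""

    @staticmethod
    def forward(ctx, C):
        with torch.no_grad():
            # Stand-in "solver": a softmax coupling. Real code would
            # run Sinkhorn here; either way, only the plan is saved.
            T = torch.softmax(-C.flatten(), dim=0).reshape(C.shape)
        ctx.save_for_backward(T)
        return (C * T).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (T,) = ctx.saved_tensors
        # Envelope theorem: treat T as constant, so dcost/dC = T.
        return grad_output * T

C = torch.randn(4, 5, requires_grad=True)
cost = EnvelopeOTCost.apply(C)
cost.backward()
# C.grad now equals the plan computed in forward.
```

Because nothing from the inner solver's iterations is stored, memory cost is one n×m tensor regardless of how many iterations the solver ran.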

Joint Embedding

After computing a transport plan, embed both datasets into a shared space:

from torchgw import joint_embedding

emb = joint_embedding(
    anchor_name="tgt",
    data_by_name={"src": X, "tgt": Y},
    graphs_by_name={"src": g_src, "tgt": g_tgt},
    transport_plans={("src", "tgt"): T},
    out_dim=10,
)
print(emb["src"].shape)  # (500, 10)
print(emb["tgt"].shape)  # (600, 10)
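Once both datasets live in the same space, standard tools apply directly. For example, cross-dataset matching with scikit-learn nearest neighbors (toy random embeddings below, standing in for the joint_embedding output):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
emb_src = rng.standard_normal((500, 10)).astype(np.float32)
emb_tgt = rng.standard_normal((600, 10)).astype(np.float32)

# For each source point, find its nearest target point in the
# shared 10-dimensional space.
nn = NearestNeighbors(n_neighbors=1).fit(emb_tgt)
dist, idx = nn.kneighbors(emb_src)

print(idx.shape)  # (500, 1)
```

The same pattern supports downstream label transfer: assign each source point the label of its matched target neighbor.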