Gromov-Wasserstein at Scale, Beyond Squared Norms
A fundamental challenge in data science is to match disparate point sets with each other. While optimal transport efficiently minimizes point displacements under a bijectivity constraint, it is inherently sensitive to rotations. Conversely, minimizing distortions via the Gromov-Wasserstein (GW) framework addresses this limitation but introduces a non-convex, computationally demanding optimization problem. In this work, we identify a broad class of distortion penalties that reduce to a simple alignment problem within a lifted feature space. Leveraging this insight, we introduce an iterative GW solver with a linear memory footprint and quadratic (rather than cubic) time complexity. Our method is differentiable, comes with strong theoretical guarantees, and scales to hundreds of thousands of points in minutes. This efficiency unlocks a wide range of geometric applications and enables the exploration of the GW energy landscape, whose local minima encode the symmetries of the matching problem.
💡 Research Summary
The paper tackles the fundamental problem of matching heterogeneous point sets, a task that underlies many modern data‑science applications such as domain adaptation, shape registration, and graph learning. Classical optimal transport (OT) aligns points by minimizing displacement under a bijectivity constraint, but it requires a common embedding space and is highly sensitive to global rotations. The Gromov‑Wasserstein (GW) framework overcomes these limitations by measuring distortion of pairwise relationships instead of absolute positions, making it suitable for graphs, point clouds, and probability measures defined on different metric spaces. However, GW is a quadratic assignment problem, which is NP‑hard in general, and existing solvers rely on entropic regularization (EGW) combined with Sinkhorn iterations. This regularization introduces a geometric bias and leads to cubic‑time and quadratic‑memory costs that prohibit scaling beyond a few thousand points.
The authors propose to broaden the class of admissible distortion penalties from the traditional squared Euclidean distance to any conditionally negative‑type (CNT) cost. CNT costs include tree distances, hyperbolic geodesics, spherical distances, and all p‑norms with 0 < p ≤ 2. By Schoenberg’s theorem, every CNT cost can be embedded isometrically into a Hilbert space: there exists a mapping φ such that c(x,x′)=‖φ(x)−φ(x′)‖². Using this embedding, each point is lifted to an augmented feature Φ(x)=(φ(x),½‖φ(x)‖²) (and likewise Ψ(y) on the second space). The key theoretical contribution (Theorem 3.2) shows that the entropic GW problem with any CNT cost can be rewritten as
GWₑ(α,β)=C(α,β)+8 · min_{Γ∈HS(H_Y,H_X)} min_{π∈M(α,β)} F(Γ,π),
where Γ is a Hilbert‑Schmidt linear operator between the two Hilbert spaces, π is a transport plan, and
F(Γ,π)=‖Γ‖_{HS}²+(ε/8) KL(π‖α⊗β)−2∬⟨Φ(x),ΓΨ(y)⟩ dπ(x,y).
Crucially, F is convex in Γ for fixed π and convex in π for fixed Γ. The optimal Γ for a given π is simply the cross‑covariance ∫ Φ(x)Ψ(y)ᵀ dπ(x,y), while the optimal π for a given Γ reduces to an entropic OT problem with the bilinear cost c_Γ(x,y)=−2⟨Φ(x),ΓΨ(y)⟩. Hence the non‑convex GW problem can be tackled by an alternating minimization scheme:
- π‑update: solve OT₍ε/8₎(α,β) with cost c_Γ (standard Sinkhorn, O(N²) time, O(N) memory).
- Γ‑update: compute the empirical cross‑covariance of the current coupling (closed‑form, O(N) operations).
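The two updates above can be sketched in a few lines of NumPy. This is a simplified toy version under the squared Euclidean cost (so φ is the identity), with illustrative choices throughout: the theorem's constants (the factor 8, ε/8) are dropped, and the N×M cost matrix is materialized for clarity, whereas the paper's linear-memory claim relies on never forming it explicitly, exploiting the low-rank factorization c_Γ = −2ΦΓΨᵀ.

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)

# Toy point clouds. With the squared Euclidean cost, phi is the identity,
# so the lifted feature is Phi(x) = (x, 0.5 * ||x||^2); sizes and scales
# here are arbitrary illustrative choices.
X = rng.normal(scale=0.5, size=(60, 2))
Y = rng.normal(scale=0.5, size=(80, 2))
Phi = np.hstack([X, 0.5 * (X ** 2).sum(1, keepdims=True)])  # (N, d + 1)
Psi = np.hstack([Y, 0.5 * (Y ** 2).sum(1, keepdims=True)])  # (M, d + 1)
a = np.full(len(X), 1.0 / len(X))   # uniform marginal alpha
b = np.full(len(Y), 1.0 / len(Y))   # uniform marginal beta
eps = 1.0                           # entropic regularization strength

def sinkhorn(C, a, b, eps, iters=500):
    """Log-domain Sinkhorn for entropic OT; the final f-update makes the
    row marginals of the returned coupling match a exactly."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(iters):
        g = -eps * logsumexp((f[:, None] - C) / eps + np.log(a)[:, None], axis=0)
        f = -eps * logsumexp((g[None, :] - C) / eps + np.log(b)[None, :], axis=1)
    return np.exp((f[:, None] + g[None, :] - C) / eps) * np.outer(a, b)

# Alternating minimization: pi-update via Sinkhorn on the bilinear cost,
# Gamma-update as the cross-covariance of the current coupling.
Gamma = np.eye(Phi.shape[1], Psi.shape[1])
for _ in range(15):
    C = -2.0 * Phi @ Gamma @ Psi.T   # bilinear cost c_Gamma(x, y)
    pi = sinkhorn(C, a, b, eps)      # pi-update (entropic OT)
    Gamma = Phi.T @ pi @ Psi         # Gamma-update (closed form)
```

Each outer iteration here costs O(NM) time because the cost matrix is built explicitly; the Sinkhorn inner loop is what a production solver would run in factored, matrix-free form.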
This decomposition yields a solver whose memory footprint is linear in the number of points and whose per‑iteration cost is quadratic, a dramatic improvement over the O(N³) complexity of classic GW solvers. Moreover, the method is fully differentiable, making it suitable for deep‑learning pipelines.
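To make the CNT condition concrete: a cost c is of conditionally negative type when ∑ᵢⱼ wᵢwⱼ c(xᵢ,xⱼ) ≤ 0 for every weight vector w summing to zero, and this is easy to probe numerically. A small sketch (the point cloud, costs, and trial count are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(40, 3))  # arbitrary point cloud in R^3

def worst_cnt_form(cost, X, trials=200):
    """Largest value of sum_ij w_i w_j cost(x_i, x_j) found over random
    zero-sum weight vectors w; stays <= 0 for a CNT cost."""
    D = cost(X[:, None, :], X[None, :, :])
    worst = -np.inf
    for _ in range(trials):
        w = rng.normal(size=len(X))
        w -= w.mean()                       # enforce sum(w) = 0
        worst = max(worst, float(w @ D @ w))
    return worst

sq_dist = lambda x, y: ((x - y) ** 2).sum(-1)        # squared Euclidean (p = 2)
dist = lambda x, y: np.sqrt(((x - y) ** 2).sum(-1))  # Euclidean distance (p = 1)
```

Both forms stay non-positive up to round-off, consistent with the claim that the p-norm costs with 0 < p ≤ 2 are CNT.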
The paper also addresses the well‑known entropic bias of EGW. Following the standard Sinkhorn‑divergence construction, the authors define a debiased Sinkhorn GW divergence
SGWₑ(α,β) = GWₑ(α,β) − ½ GWₑ(α,α) − ½ GWₑ(β,β),
which subtracts the self‑transport terms so that the divergence vanishes when α = β.
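This debiasing is the usual Sinkhorn-divergence recipe and is easy to express generically. A hypothetical sketch, where `gw_entropic` stands in for any routine returning an entropic GW value; the stub below is not a real solver, only a symmetric value with a constant additive bias, used to show the cancellation:

```python
import numpy as np

def sinkhorn_gw_divergence(gw_entropic, alpha, beta):
    """Debias an entropic value by subtracting the self-transport terms
    (the standard Sinkhorn-divergence construction)."""
    return (gw_entropic(alpha, beta)
            - 0.5 * gw_entropic(alpha, alpha)
            - 0.5 * gw_entropic(beta, beta))

# Stand-in for a real solver: symmetric, with a constant "entropic" bias of 1.
stub = lambda a, b: float(((a - b) ** 2).sum()) + 1.0

x, y = np.array([0.0, 1.0]), np.array([0.0, 2.0])
print(sinkhorn_gw_divergence(stub, x, x))  # bias cancels: 0.0
print(sinkhorn_gw_divergence(stub, x, y))  # 1.0
```

The constant bias drops out exactly, and the divergence is zero when the two inputs coincide.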