Learning, Solving and Optimizing PDEs with TensorGalerkin: an efficient high-performance Galerkin assembly algorithm
We present a unified algorithmic framework for the numerical solution, constrained optimization, and physics-informed learning of PDEs with a variational structure. Our framework is based on a Galerkin discretization of the underlying variational forms, and its high efficiency stems from a novel highly-optimized and GPU-compliant TensorGalerkin framework for linear system assembly (stiffness matrices and load vectors). TensorGalerkin tensorizes element-wise operations within a Python-level Map stage and then performs global reduction via a sparse matrix multiplication that realizes message passing on the mesh-induced sparsity graph. It can be seamlessly employed downstream as i) a highly-efficient numerical PDE solver, ii) an end-to-end differentiable framework for PDE-constrained optimization, and iii) a physics-informed operator learning algorithm for PDEs. With multiple benchmarks, including 2D and 3D elliptic, parabolic, and hyperbolic PDEs on unstructured meshes, we demonstrate that the proposed framework provides significant computational efficiency and accuracy gains over a variety of baselines in all the targeted downstream applications.
💡 Research Summary
The paper introduces TensorGalerkin, a unified high‑performance framework for solving, learning, and optimizing partial differential equations (PDEs) that admit a variational formulation. Traditional finite‑element method (FEM) pipelines suffer from severe bottlenecks when implemented in Python‑based automatic‑differentiation (AD) environments: element‑wise loops incur interpreter overhead, scatter‑add reductions fragment the computational graph, and back‑propagation becomes memory‑ and time‑intensive. TensorGalerkin eliminates these issues by recasting the Galerkin assembly as a strictly tensorized Map‑Reduce operation that runs entirely on the GPU.
In the Map (Batch‑Map) stage, all elements are lifted to a batch dimension. Geometric data (coordinates, Jacobians), basis function gradients, and physical coefficients are stored as high‑dimensional tensors. A single torch.einsum call contracts the quadrature weights, transformed gradients, and coefficient values to produce a dense local stiffness tensor K_local ∈ ℝ^{E×k×k} for all E elements simultaneously. This eliminates per‑element loops, yields a compact computation graph, and exploits batched GEMM kernels for maximal throughput.
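The contraction pattern of the Map stage can be sketched in a few lines. This is a minimal NumPy illustration (torch.einsum uses the same subscript notation); the shapes and variable names here are assumptions for a generic diffusion-type bilinear form, not the paper's actual code.

```python
import numpy as np

# Hypothetical shapes: E elements, Q quadrature points,
# k local basis functions, d spatial dimensions.
E, Q, k, d = 4, 3, 3, 2
rng = np.random.default_rng(0)

w    = rng.random(Q)             # quadrature weights
detJ = rng.random((E, Q)) + 1.0  # Jacobian determinants per element/point
a    = rng.random((E, Q)) + 1.0  # physical coefficient (e.g., diffusivity)
G    = rng.random((E, Q, k, d))  # physical-space basis gradients

# One einsum builds every local stiffness block at once:
#   K_local[e, i, j] = sum_q w[q] * a[e,q] * detJ[e,q] * <grad phi_i, grad phi_j>
K_local = np.einsum('q,eq,eq,eqid,eqjd->eij', w, a, detJ, G, G)

assert K_local.shape == (E, k, k)
# Each block is symmetric, as expected for a diffusion bilinear form.
assert np.allclose(K_local, K_local.transpose(0, 2, 1))
```

On a GPU, the same contraction dispatches to batched GEMM kernels, which is where the throughput gain over per-element Python loops comes from.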
The Reduce (Sparse‑Reduce) stage replaces the traditional scatter‑add with a deterministic sparse matrix multiplication (SpMM). A pre‑computed binary routing matrix encodes the mapping from local degrees of freedom to global indices. Global assembly is performed as K = Rᵀ·K_local·R, where R is the routing matrix. Because SpMM is a single, well‑optimized kernel, the reduction is both fast and fully differentiable, and the same operation can be reused during back‑propagation without recreating fragmented graphs.
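The Reduce stage can be checked against classic scatter-add assembly on a toy mesh. The sketch below (connectivity, shapes, and names are illustrative assumptions, not the paper's code) builds the binary routing matrix R from an element-to-global connectivity table and assembles via the SpMM sandwich K = Rᵀ·blockdiag(K_local)·R.

```python
import numpy as np
import scipy.sparse as sp

E, k, N = 2, 3, 4  # 2 elements, 3 local DoFs each, 4 global DoFs
# Hypothetical element -> global DoF connectivity (two triangles sharing an edge)
conn = np.array([[0, 1, 2],
                 [1, 2, 3]])

# Routing matrix R: (E*k) x N binary; row e*k+i has a 1 at column conn[e, i]
rows = np.arange(E * k)
cols = conn.ravel()
R = sp.csr_matrix((np.ones(E * k), (rows, cols)), shape=(E * k, N))

# Dense local stiffness blocks (stand-ins for the Map-stage output)
rng = np.random.default_rng(0)
K_local = rng.random((E, k, k))

# Global assembly as a single SpMM sandwich
K = (R.T @ sp.block_diag(list(K_local)) @ R).toarray()

# Reference: traditional scatter-add assembly gives the same matrix
K_ref = np.zeros((N, N))
for e in range(E):
    for i in range(k):
        for j in range(k):
            K_ref[conn[e, i], conn[e, j]] += K_local[e, i, j]
assert np.allclose(K, K_ref)
```

Because R is fixed by the mesh, it is built once and reused for every assembly and every backward pass; the sparse products above are deterministic, unlike atomic scatter-adds.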
Built on top of this core, the authors deliver three downstream tools:
- TensorMesh – a GPU-native FEM solver that solves the linear system K U = F with standard Krylov methods. Benchmarks on 2-D and 3-D elliptic, parabolic, and hyperbolic problems show speed-ups of 5–30× over CPU-based FEM libraries (e.g., FEniCS, deal.II) and 2–10× over existing JAX-FEM implementations, while reducing peak memory usage by 40–70 %.
- TensorPils – a physics-informed operator learning algorithm. Instead of using automatic differentiation to compute spatial derivatives (as in PINNs), TensorPils evaluates the residual ‖K(ρ) U_θ(ρ) − F(ρ)‖² analytically using the same tensorized assembly, thus avoiding deep computational graphs. Experiments demonstrate convergence 4× faster than PINNs and comparable or better L2 errors with far fewer training epochs.
- TensorOpt – an end-to-end differentiable pipeline for PDE-constrained optimization (e.g., inverse design). Gradients with respect to design parameters ρ are obtained efficiently via the same tensorized assembly, cutting gradient-computation time by over 80 % and enabling near-real-time design loops.
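The solver stage of this pipeline reduces to a Krylov solve once K and F are assembled. A minimal SciPy sketch, using a 1-D Poisson stiffness matrix as a stand-in for an assembled K (the matrix, size, and solver choice here are illustrative assumptions, not the paper's benchmarks):

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import cg

# Stand-in for an assembled stiffness matrix: 1-D Poisson, N interior nodes
N = 50
main = 2.0 * np.ones(N)
off = -1.0 * np.ones(N - 1)
K = sp.diags([off, main, off], [-1, 0, 1], format='csr')
F = np.ones(N)  # stand-in load vector

# Conjugate-gradient solve of K U = F, as a solver like TensorMesh would do
U, info = cg(K, F)
assert info == 0  # converged
assert np.linalg.norm(K @ U - F) < 1e-3
```

Since the assembled K is SPD for elliptic problems, CG is the natural Krylov method; for the differentiable downstream uses, the same assembly operations are reused in the backward pass rather than rebuilt.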
The paper also discusses limitations: the current implementation is tuned for low‑order polynomial bases and regular element types; scaling the routing matrix to problems with millions of degrees of freedom may require compressed representations; and handling strongly nonlinear stiffness matrices would need iterative re‑assembly strategies. Nonetheless, TensorGalerkin represents a significant step toward practical, GPU‑accelerated FEM in Python ecosystems and opens new possibilities for data‑scarce physics‑informed learning and large‑scale PDE‑based optimization.