scDFM: Distributional Flow Matching Model for Robust Single-Cell Perturbation Prediction
A central goal in systems biology and drug discovery is to predict the transcriptional response of cells to perturbations. This task is challenging due to the noisy and sparse nature of single-cell measurements, as well as the fact that perturbations often induce population-level shifts rather than changes in individual cells. Existing deep learning methods typically assume cell-level correspondences, limiting their ability to capture such global effects. We present scDFM, a generative framework based on conditional flow matching that models the full distribution of perturbed cells conditioned on control states. By incorporating a maximum mean discrepancy (MMD) objective, our method aligns perturbed and control populations beyond cell-level correspondences. To further improve robustness to sparsity and noise, we introduce the Perturbation-Aware Differential Transformer (PAD-Transformer), a backbone architecture that leverages gene interaction graphs and differential attention to capture context-specific expression changes. Across multiple genetic and drug perturbation benchmarks, scDFM consistently outperforms prior methods, demonstrating strong generalization in both unseen and combinatorial settings. In the combinatorial setting, it reduces mean squared error by 19.6% relative to the strongest baseline. These results highlight the importance of distribution-level generative modeling for robust in silico perturbation prediction. The code is available at https://github.com/AI4Science-WestlakeU/scDFM.
💡 Research Summary
The paper introduces scDFM, a novel generative framework for predicting single‑cell transcriptional responses to genetic or drug perturbations. The authors begin by highlighting two fundamental challenges in this domain: (1) single‑cell RNA‑seq data are inherently noisy, sparse, and zero‑inflated, making it difficult to learn reliable cell‑level mappings; (2) perturbations often induce population‑level shifts—changes in sub‑population proportions, variance, and higher‑order moments—rather than simple per‑cell expression changes. Existing deep‑learning approaches typically assume a one‑to‑one correspondence between control and perturbed cells, focusing on mean expression reconstruction and consequently missing these distributional effects.
To address these issues, scDFM combines conditional flow matching (CFM) with a multi‑kernel maximum mean discrepancy (MMD) regularizer. CFM, originally proposed for continuous‑time generative modeling, learns a time‑dependent velocity field vθ(x_t | t, c_x, c_p) that morphs a noisy source distribution (x₀) into the target perturbed distribution (x₁). The source is generated by adding Gaussian noise to control cells c_x, while the target is drawn from the empirical post‑perturbation distribution conditioned on the same control and a perturbation embedding c_p (a multi‑hot vector). The model minimizes the L2 distance between the predicted velocity and the analytical velocity of a linear interpolation path π_t(x₀, x₁) = (1 − t)x₀ + t x₁, ensuring that each intermediate state follows a biologically plausible trajectory.
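The CFM objective described above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: `toy_velocity` is a hypothetical stand-in for the PAD-Transformer backbone, and the conditioning vector `c` lumps together the control state c_x and perturbation embedding c_p.

```python
import numpy as np

rng = np.random.default_rng(0)

def cfm_loss(v_theta, x0, x1, c, rng):
    """Conditional flow matching loss for one mini-batch.

    x0: noisy source cells (control cells + Gaussian noise), shape (B, G)
    x1: perturbed target cells, shape (B, G)
    c:  conditioning (control state + perturbation embedding), shape (B, C)
    """
    B = x0.shape[0]
    t = rng.uniform(size=(B, 1))         # random time in [0, 1]
    x_t = (1.0 - t) * x0 + t * x1        # linear interpolation path pi_t(x0, x1)
    u_t = x1 - x0                        # analytical velocity of the linear path
    v_pred = v_theta(x_t, t, c)          # model's predicted velocity
    return np.mean((v_pred - u_t) ** 2)  # L2 regression on the velocity field

# Toy "model" that always predicts zero velocity, standing in for
# the learned backbone; it exists only to make the loss computable.
def toy_velocity(x_t, t, c):
    return np.zeros_like(x_t)

x0 = rng.normal(size=(64, 10))
x1 = x0 + 1.0                            # target shifted by +1 per gene
c = np.zeros((64, 4))
loss = cfm_loss(toy_velocity, x0, x1, c, rng)
# The true velocity is exactly 1 everywhere, so the zero predictor has loss 1.
```

Note that for the linear path, the regression target u_t = x₁ − x₀ does not depend on t, which is what makes the loss a simple simulation-free regression rather than an ODE solve.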
Because CFM only enforces local dynamical consistency, the authors augment training with an MMD term that directly compares the set of generated terminal samples (x̂₁) with real perturbed cells x₁. They employ a mixture of Gaussian RBF kernels with bandwidths chosen via a median heuristic, allowing the loss to capture discrepancies across multiple scales (mean, variance, and higher‑order structure). The final objective is L = L_CFM + λ L_MMD, where λ balances trajectory fidelity against distributional alignment.
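A multi-kernel MMD of this form can be sketched as follows. This is an illustrative version (a biased V-statistic for brevity; the paper's exact kernel mixture and scale set are not specified here, so `scales` is an assumed choice):

```python
import numpy as np

def multi_kernel_mmd(X, Y, scales=(0.5, 1.0, 2.0)):
    """Multi-kernel squared MMD between two sample sets X (n, d) and Y (m, d).

    Each RBF bandwidth is the median pairwise squared distance times a
    scale factor (median heuristic), so the statistic is sensitive to
    discrepancies at several scales.
    """
    Z = np.vstack([X, Y])
    d2 = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    median = np.median(d2[d2 > 0])       # median heuristic bandwidth
    n = len(X)
    mmd2 = 0.0
    for s in scales:
        K = np.exp(-d2 / (s * median))
        # E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
        mmd2 += K[:n, :n].mean() + K[n:, n:].mean() - 2 * K[:n, n:].mean()
    return mmd2

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(200, 5))
Y_same = rng.normal(0.0, 1.0, size=(200, 5))   # same distribution as X
Y_shift = rng.normal(1.0, 1.0, size=(200, 5))  # mean-shifted distribution
```

With matched distributions the statistic is near zero, while a population-level shift in mean (or variance, via the smaller bandwidths) drives it up, which is exactly the behavior the L_MMD term exploits.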
A key contribution is the Perturbation‑Aware Differential Transformer (PAD‑Transformer), which serves as the backbone for encoding gene expression and predicting velocities. PAD‑Transformer incorporates three innovations: (i) a gene‑gene co‑expression graph is used as an attention mask, forcing each gene token to attend only to biologically related neighbors; (ii) differential attention separates control and perturbation tokens, enabling the model to learn interaction‑specific representations; and (iii) time embeddings are injected at each layer, providing explicit temporal context for the continuous flow. Gene embeddings are obtained via a cross‑attention encoder, and expression values are projected into a shared latent space before entering the transformer blocks. This architecture mitigates over‑fitting to noise, respects regulatory network structure, and scales to high‑dimensional gene spaces.
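Innovations (i) and (ii) can be illustrated with a single attention head that masks scores with the gene graph and subtracts two softmax maps. This is a sketch under stated assumptions: the parameter shapes, the mixing weight `lam`, and the function name are my own, and the per-layer time-embedding injection (iii) is omitted for brevity.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def graph_masked_diff_attention(X, adj, Wq1, Wk1, Wq2, Wk2, Wv, lam=0.5):
    """One graph-masked differential-attention head (illustrative only).

    X:   gene-token embeddings, shape (G, d)
    adj: boolean gene-gene co-expression adjacency, shape (G, G), with
         self-loops; scores outside the graph are set to -inf before
         softmax, so each gene attends only to biologically related genes.
    Two attention maps are computed and subtracted (differential
    attention), cancelling common-mode attention that both maps assign
    to noisy, uninformative tokens.
    """
    d = Wq1.shape[1]
    def scores(Wq, Wk):
        s = (X @ Wq) @ (X @ Wk).T / np.sqrt(d)
        return np.where(adj, s, -np.inf)   # graph mask
    A = softmax(scores(Wq1, Wk1)) - lam * softmax(scores(Wq2, Wk2))
    return A @ (X @ Wv)

rng = np.random.default_rng(1)
G, d = 6, 8
adj = rng.random((G, G)) < 0.4
adj |= np.eye(G, dtype=bool)               # self-loops keep every row valid
X = rng.normal(size=(G, d))
Wq1, Wk1, Wq2, Wk2, Wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(5))
out = graph_masked_diff_attention(X, adj, Wq1, Wk1, Wq2, Wk2, Wv)
```

Because softmax assigns zero weight to −inf scores, genes outside a token's neighborhood contribute nothing to its update, which is how the co-expression graph constrains information flow.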
The authors evaluate scDFM on two challenging benchmarks. The Norman dataset contains CRISPR‑based gene knock‑outs and combinatorial double‑knock‑outs; experiments test (a) an “additive” setting where all single knock‑outs and a subset of doubles are seen during training, and (b) a “hold‑out” setting where specific double combinations are completely unseen. The Combosciplex dataset comprises drug‑pair perturbations with similar split strategies. Metrics include mean squared error (MSE), Pearson correlation, and the raw MMD between generated and true distributions. Across all settings, scDFM outperforms state‑of‑the‑art baselines such as CPA, GEARS, CellFlow, UNLASTING, and recent transformer‑based foundation models (Geneformer, scGPT). Notably, in the combinatorial (unseen) scenario scDFM reduces MSE by 19.6% relative to the strongest baseline and achieves substantially lower MMD scores, demonstrating superior capture of population‑level shifts.
Ablation studies confirm the importance of each component: removing the MMD term degrades distributional fidelity, while replacing PAD‑Transformer with a vanilla transformer or a simple MLP reduces robustness to sparsity and harms performance on unseen perturbations. The authors also discuss computational efficiency: the flow‑matching loss avoids costly reverse‑diffusion sampling, and the multi‑kernel MMD can be computed in linear time with respect to batch size using unbiased estimators.
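The linear-time claim can be made concrete with the classic streaming MMD estimator of Gretton et al., which pairs consecutive samples so that each kernel evaluation is used once. This is a sketch of that general technique, not necessarily the exact estimator scDFM uses; the fixed `gamma` bandwidth is an assumption for illustration.

```python
import numpy as np

def rbf(a, b, gamma):
    """RBF kernel evaluated row-wise on two equal-shape sample arrays."""
    return np.exp(-gamma * np.sum((a - b) ** 2, axis=-1))

def linear_time_mmd2(X, Y, gamma=0.5):
    """Unbiased linear-time MMD^2 estimator.

    Uses O(n) kernel evaluations instead of the O(n^2) required by the
    full U-statistic: samples are consumed in disjoint consecutive pairs.
    """
    m = (min(len(X), len(Y)) // 2) * 2     # even number of samples
    x1, x2 = X[:m:2], X[1:m:2]
    y1, y2 = Y[:m:2], Y[1:m:2]
    h = (rbf(x1, x2, gamma) + rbf(y1, y2, gamma)
         - rbf(x1, y2, gamma) - rbf(x2, y1, gamma))
    return h.mean()

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(2000, 3))
Y_same = rng.normal(0.0, 1.0, size=(2000, 3))
Y_shift = rng.normal(2.0, 1.0, size=(2000, 3))
# Near zero for matched distributions; clearly positive under a mean shift.
```

The trade-off is higher estimator variance than the quadratic-time version, which matters less in training because the loss is averaged over many mini-batches.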
In summary, scDFM advances single‑cell perturbation modeling by (1) framing the problem as a conditional continuous‑time flow that respects the stochastic nature of cellular transitions, (2) explicitly aligning generated and real distributions via MMD, and (3) leveraging a graph‑aware, differential transformer to encode regulatory relationships and perturbation context. The work demonstrates that distribution‑level generative modeling, rather than pointwise regression, is essential for accurate, generalizable in‑silico prediction of cellular responses, especially in combinatorial settings where experimental data are scarce. The authors release code and pretrained models, opening the door for downstream applications such as personalized drug response prediction, synthetic perturbation screening, and integration with multimodal single‑cell atlases.