U-REPA: Aligning Diffusion U-Nets to ViTs
Representation Alignment (REPA), which aligns Diffusion Transformer (DiT) hidden states with ViT visual encoders, has proven highly effective in DiT training, yielding markedly faster convergence; however, it has not been validated on the canonical diffusion U-Net architecture, which itself converges faster than DiTs. Adapting REPA to U-Net architectures presents unique challenges: (1) different block functionalities necessitate revised alignment strategies; (2) spatial-dimension inconsistencies emerge from U-Net’s spatial downsampling operations; (3) the space gap between U-Net and ViT features hinders the effectiveness of tokenwise alignment. To address these challenges, we propose \textbf{U-REPA}, a representation alignment paradigm that bridges U-Net hidden states and ViT features as follows. First, we observe that, owing to skip connections, the middle stage of the U-Net is the best alignment option. Second, we upsample U-Net features after passing them through MLPs. Third, observing that tokenwise similarity alignment alone is difficult, we further introduce a manifold loss that regularizes the relative similarity between samples. Experiments indicate that the resulting U-REPA achieves excellent generation quality and greatly accelerates convergence. With a CFG guidance interval, U-REPA reaches $\mathrm{FID} < 1.5$ in 200 epochs or 1M iterations on ImageNet 256 $\times$ 256, and needs only half the total epochs to outperform REPA under sd-vae-ft-ema. Code: https://github.com/YuchuanTian/U-REPA
💡 Research Summary
The paper introduces U‑REPA, a novel representation‑alignment framework that extends the successful REPA technique—originally designed for Diffusion Transformers (DiT) and Vision Transformers (ViT)—to the canonical diffusion U‑Net architecture. While REPA has demonstrated accelerated convergence and superior sample quality when aligning DiT hidden states with ViT embeddings, its direct application to U‑Nets is non‑trivial due to three fundamental mismatches: (1) differing block functionalities caused by U‑Net’s skip connections, (2) spatial‑dimension inconsistencies arising from the multi‑scale down‑sampling/up‑sampling pipeline, and (3) a large “space gap” between the token‑wise representations of U‑Net and ViT, which renders cosine‑similarity‑based losses ineffective.
To overcome these obstacles, the authors propose three key design choices. First, they empirically find that the middle stage of a U‑Net—situated after the down‑sampling path but before the up‑sampling decoder—is the most semantically rich layer for alignment. This is attributed to the way skip connections redistribute semantic information across the network, making the intermediate features the best proxy for the high‑level representations captured by ViT. Second, to reconcile spatial resolution, they explore three up‑scaling pipelines and find that applying a lightweight MLP to the low‑resolution U‑Net features before up‑sampling yields the best trade‑off between computational cost and performance. The MLP expands the channel dimension, after which a pixel‑shuffle (or similar channel‑to‑space) operation restores the spatial size to match ViT’s patch grid. Third, recognizing that direct token‑wise cosine similarity is too rigid given the architectural gap, they introduce a manifold loss that regularizes the relative similarity among samples in a batch. This loss aligns the pairwise similarity matrices of U‑Net and ViT embeddings, encouraging the two feature spaces to share the same relational geometry rather than exact point‑wise correspondence.
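The "MLP first, then upscale" pipeline described above can be sketched in a few lines. This is a hedged illustration, not the paper's implementation: the weight shapes, ReLU activation, and scale factor are illustrative assumptions, and the pixel-shuffle rearrangement stands in for whatever channel-to-space operation the authors use.

```python
import numpy as np

def mlp_then_upscale(feats, w1, w2, scale=2):
    """Sketch of aligning low-res U-Net features to a ViT patch grid.

    feats: (N, H, W, C) low-resolution U-Net hidden states.
    A small MLP expands the channel dimension by scale**2, then a
    pixel-shuffle-style rearrangement trades channels for spatial
    resolution, matching the ViT encoder's (H*scale, W*scale) token grid.
    """
    n, h, w, _ = feats.shape
    hidden = np.maximum(feats @ w1, 0.0)        # (N, H, W, C_hidden), ReLU MLP
    expanded = hidden @ w2                      # (N, H, W, C_out * scale**2)
    c_out = expanded.shape[-1] // (scale * scale)
    # pixel shuffle: (N, H, W, s, s, C_out) -> (N, H*s, W*s, C_out)
    x = expanded.reshape(n, h, w, scale, scale, c_out)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(n, h * scale, w * scale, c_out)
    return x

# Toy check: an 8x8 U-Net bottleneck grid upscaled to a 16x16 patch grid.
rng = np.random.default_rng(0)
feats = rng.standard_normal((2, 8, 8, 64))
w1 = rng.standard_normal((64, 128)) * 0.1
w2 = rng.standard_normal((128, 32 * 2 * 2)) * 0.1  # 32 output channels, scale=2
out = mlp_then_upscale(feats, w1, w2, scale=2)
print(out.shape)  # (2, 16, 16, 32)
```

Doing the MLP before up-scaling keeps the matrix multiplications at the low resolution, which is why the summary reports it as the cheapest of the three explored pipelines.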
Extensive experiments on ImageNet‑256 validate the approach. With classifier‑free guidance (CFG) enabled, U‑REPA reaches an FID below 1.5 after only 200 epochs (≈1 M training iterations), roughly half the training budget required by the original REPA under the same sd‑vae‑ft‑ema setting, while achieving a lower final FID of 1.41. Ablation studies confirm that (i) middle‑stage alignment is crucial, (ii) “MLP‑first‑then‑upscale” is the most efficient scaling strategy, and (iii) the manifold loss substantially improves convergence stability compared to a pure cosine‑similarity loss. Additional analyses show that the speed advantage of U‑Nets over DiTs stems primarily from hierarchical down‑sampling, which creates compact, semantically dense bottleneck representations; skip connections further mitigate information loss but are not the main driver of fast convergence.
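A minimal version of the manifold (relational) loss contrasted with token-wise cosine similarity above might look as follows. This is a sketch of the general idea, not the paper's exact formulation: the normalized Gram-matrix construction, pooling to per-sample embeddings, and squared-error penalty are all illustrative assumptions.

```python
import numpy as np

def manifold_loss(u_feats, v_feats, eps=1e-8):
    """Align the relational geometry of two embedding spaces.

    Rather than forcing each U-Net token to match its ViT counterpart,
    compare the (N, N) cosine-similarity matrices over samples in a batch,
    so both spaces agree on which samples are mutually similar.

    u_feats, v_feats: (N, D) per-sample pooled embeddings.
    """
    def gram(x):
        x = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
        return x @ x.T                          # pairwise cosine similarities

    g_u, g_v = gram(u_feats), gram(v_feats)
    return float(np.mean((g_u - g_v) ** 2))    # penalize geometry mismatch

rng = np.random.default_rng(0)
u = rng.standard_normal((4, 16))
v = rng.standard_normal((4, 16))
zero = manifold_loss(u, u)      # identical geometry -> zero loss
gap = manifold_loss(u, v)       # mismatched geometry -> positive loss
```

Because only relative similarities are matched, the loss tolerates the large "space gap" between U-Net and ViT features that makes point-wise cosine alignment brittle.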
The paper also proposes DiT↓ and SiT↓ variants that incorporate recent tricks such as rotary positional embeddings (RoPE) and SwiGLU, demonstrating that the insights from U‑REPA can be transferred back to transformer‑based diffusion models. All code and pretrained weights are released, facilitating reproducibility and future research.
In summary, U‑REPA bridges the architectural divide between diffusion U‑Nets and ViT encoders through (1) strategic middle‑layer selection, (2) MLP‑guided up‑scaling of low‑resolution features, and (3) a manifold‑based relational alignment loss. This combination yields dramatically faster convergence and state‑of‑the‑art image quality, establishing a new baseline for diffusion models that combine the efficiency of U‑Nets with the semantic richness of modern vision transformers.