Polynomial Speedup in Diffusion Models with the Multilevel Euler-Maruyama Method
Authors: Arthur Jacot
March 26, 2026

Abstract

We introduce the Multilevel Euler-Maruyama (ML-EM) method to compute solutions of SDEs and ODEs using a range of approximators $f^1,\dots,f^k$ to the drift $f$ with increasing accuracy and computational cost, requiring only a few evaluations of the most accurate $f^k$ and many evaluations of the cheaper $f^1,\dots,f^{k-1}$. If the drift lies in the so-called Harder than Monte Carlo (HTMC) regime, i.e. it requires $\epsilon^{-\gamma}$ compute to be $\epsilon$-approximated for some $\gamma > 2$, then ML-EM $\epsilon$-approximates the solution of the SDE with $\epsilon^{-\gamma}$ compute, improving over the traditional EM rate of $\epsilon^{-\gamma-1}$. In other terms, it allows us to solve the SDE at the same cost as a single evaluation of the drift. In the context of diffusion models, the different levels $f^1,\dots,f^k$ are obtained by training UNets of increasing sizes, and ML-EM allows us to perform sampling at the equivalent cost of a single evaluation of the largest UNet. Our numerical experiments confirm our theory: we obtain up to fourfold speedups for image generation on the CelebA dataset downscaled to $64\times 64$, where we measure a $\gamma \approx 2.5$. Given that this is a polynomial speedup, we expect even stronger speedups in practical applications, which involve networks that are orders of magnitude larger.

1 Introduction

Denoising Diffusion Probabilistic Models (DDPMs) [25, 9, 28] are the state-of-the-art technique for generating images [22], videos, and much more [29, 1]. The images are generated by a diffusion process, a Stochastic Differential Equation (SDE), or sometimes an ODE, which starts from Gaussian noise and ends up with a distribution that approximates the 'true data distribution'.
The drift term is learned with a large Deep Neural Network (DNN), typically a UNet, a type of Convolutional Neural Network (CNN), so that the computational cost of DDPMs is dominated by the number of DNN evaluations (sometimes written NFE: "number of function evaluations"), which has to increase to reach higher levels of accuracy, or equivalently a smaller error $\epsilon$ between the continuous SDE solution and the chosen discretization. This high computational cost has a large environmental impact given the widespread use of these models and the large scale of the underlying DNNs. It also limits the application of these techniques to settings that require real-time generation, such as music.

Several methods have been proposed to reduce this computational cost. The SDE noise forces relatively small step sizes, so [26] have proposed replacing the backward SDE with a backward ODE (or something in between, an SDE with a smaller Brownian term), called Denoising Diffusion Implicit Models (DDIM). DDIMs can achieve realistic images with an order of magnitude fewer steps than DDPMs, thus drastically reducing the NFE. Assuming perfect estimation of the score, the final distribution should be the same for DDPMs and DDIMs; in practice, however, it appears that DDIMs result in slightly lower quality images, even at the smallest step sizes. Several other methods have been proposed to further 'straighten' the flow [14, 13], or to reduce the dimensionality by working in a latent space [22]. The ODE formulation also opens the door to adapting efficient solvers such as the Runge-Kutta family of methods to DDIMs [16, 17]. Yet another approach is to train a new DNN to implement multiple steps of the denoising process, or even the full denoising process [23, 19, 24, 27].

1.1 Compute Scaling Analysis

We start with a simple 'napkin math' analysis of the scaling $T \sim \epsilon^{-\alpha}$ of the compute time $T$ required to generate an image with error $\epsilon$.
The error $\epsilon$ can be decomposed into a discretization error $\epsilon_{discr}$ and an approximation error $\epsilon_{approx}$.

Discretization error: We expect the discretization error $\epsilon_{discr}$ to scale as $n^{-1/\phi}$, where $n$ is the number of steps or NFE. For example, one has $\phi = 1$ for the Euler-Maruyama method and $\phi = \frac{1}{4}$ for the 4th-order Runge-Kutta method.

Approximation error: We expect the approximation error $\epsilon_{approx}$ between our DNN and the 'true denoiser' (or equivalently the true score) to follow a typical scaling law $\epsilon_{approx} = N^{-1/\psi} + P^{-1/\gamma}$, for $N$ the number of training samples and $P$ the number of parameters, matching what has been observed empirically [12, 8, 10]. Since the training data size is irrelevant to the computational cost of generating images, we drop the data-size term $N^{-1/\psi}$ (i.e. we assume that we always have enough data that the approximation error is dominated by the network size, not the dataset size). We will consider the rate $\gamma$ to be given, but there exists a range of theoretical works that give predictions for $\gamma$ in a number of settings [2, 4, 5, 20].

Assuming that the computational cost of a DNN is proportional to $P$, it is optimal to choose $n \sim \epsilon^{-\phi}$ and $P \sim \epsilon^{-\gamma}$ to reach an error of $\epsilon$, at a computational cost of order $nP \sim \epsilon^{-(\phi+\gamma)}$. This leads to a scaling law of $\epsilon^{-(\gamma+1)}$ for the Euler-Maruyama method and $\epsilon^{-(\gamma+\frac14)}$ for Runge-Kutta or a variant thereof. Most of the aforementioned methods should only lead to constant improvements, with no improvement in the exponents, except for the improved solvers inspired by Runge-Kutta, as already mentioned. In this framework, it seems that a rate of $\epsilon^{-\gamma}$ is impossible, because any discretization method would require a growing NFE to reach smaller and smaller errors. This paper shows that a rate of $\epsilon^{-\gamma}$ is actually possible if we assume that the score is hard enough to approximate, namely $\gamma > 2$.
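As an illustration of this napkin math, the following sketch (with all constants set to 1, and hypothetical helper names `em_generation_cost` and `empirical_exponent`) recovers the cost exponent $\phi + \gamma$ numerically:

```python
import math

def em_generation_cost(eps, gamma, phi=1.0):
    """Napkin-math compute cost of generating one sample at error eps.

    Balancing the two error sources requires n ~ eps^-phi solver steps with
    a network of size P ~ eps^-gamma, for a total cost n * P ~ eps^-(phi + gamma).
    """
    n_steps = eps ** -phi    # NFE needed to reach discretization error eps
    params = eps ** -gamma   # network size needed to reach approximation error eps
    return n_steps * params

def empirical_exponent(cost_fn, eps=1e-3):
    """Recover alpha in cost ~ eps^-alpha by halving eps."""
    return math.log(cost_fn(eps / 2) / cost_fn(eps)) / math.log(2)
```

With the $\gamma \approx 2.5$ measured later in the paper, this gives an exponent of $3.5$ for Euler-Maruyama ($\phi = 1$) and $2.75$ for a 4th-order solver ($\phi = \frac14$).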
Inspired by Multilevel Monte Carlo (MLMC) methods, we rely on a Multilevel Euler-Maruyama method, where a range of DNNs of increasing sizes are randomly used at each discretization step, using large DNNs with low probability and small ones with high probability, thus leading to a small overall computational cost. This is especially impactful for SDEs, where no fast discretization methods exist (i.e. no Runge-Kutta or analogue with $\phi < 1$).

It might seem almost paradoxical that assuming the task is "hard enough" allows us to gain speedups. But previous work has shown the emergence of beneficial properties such as convexity in the so-called Harder than Monte Carlo (HTMC) regime (when $\gamma > 2$) [11]. The same paper [11] also proves a connection between the sets of functions approximable with a DNN and different HTMC spaces, which motivates our HTMC assumption ($\gamma > 2$). This assumption is further motivated by the empirical evidence that DNNs trained on images follow scaling laws [8], with empirical rates $\epsilon^2 \sim P^{-0.24}$ for $8\times 8$ images, $P^{-0.22}$ for $16\times 16$, and $P^{-0.13}$ for $32\times 32$ images, which would correspond to $\gamma \approx 8.3$, $\gamma \approx 9.1$, and $\gamma \approx 15.4$ respectively. These are all far into the HTMC regime, and the bigger the image, the larger the rate $\gamma$. This paper is just a first example of the kind of speedups that can be obtained under the HTMC assumption.

2 Setup

This paper focuses on the efficient approximation of SDEs of the form $dx_t = f_t(x_t)\,dt + \sigma_t\,dW_t$. For simplicity, we assume that the noise is isotropic and that its variance $\sigma_t^2$ depends on the time $t$ but not on $x_t$. We then consider ODEs as the special case $\sigma_t = 0$.

Example 1 (Denoising Diffusion Probabilistic Model - DDPM). The main motivation is to apply our method to the reverse diffusion process (which starts from a large $T > 0$ and then goes back in time until reaching $t = 0$, hence the minus in front of $dx_t$):
$$-dx_t = \left(\tfrac12 x_t + s_t(x_t)\right)dt + dW_t$$

for the score $s_t(x_t) = \nabla \log \rho_t$, where $\rho_t$ is the distribution of $\sqrt{e^{-t}}\,x_0 + \sqrt{1 - e^{-t}}\,\mathcal N(0,1)$.

Example 2 (Denoising Diffusion Implicit Model - DDIM). We will also consider the probability flow ODE, which has the same marginal distribution at each time $t$, but has no noise term: $-\frac{dx_t}{dt} = \tfrac12 x_t + \tfrac12 s_t(x_t)$.

What makes the diffusion SDE unique is the fact that the score $s_t$ can only be approximated by very large DNNs, and these DNNs generally follow scaling laws, i.e. the size of the network (and therefore its computational cost) must scale rapidly if one wants to obtain more and more accurate approximations of the true score. We formalize this into the following assumption:

Assumption 1 (Scaling Law Assumption). There is a sequence of estimators $f_t^k$ which approximate $f_t$ within a $2^{-k}$ error, $\|f_t - f_t^k\|_\infty \le 2^{-k}$ for all $t \in [0, T]$, and whose compute $C(f_t^k)$ scales exponentially in $k$: $C(f_t^k) \le c^\gamma 2^{\gamma k}$ for some $\gamma$ (the convention of taking the constant $c$ to the $\gamma$-th power will lead to cleaner formulas).

This assumption is motivated by the strong empirical evidence that DNNs follow scaling laws [12, 8]: there is a $\gamma$ such that the test error can be bounded as $\sqrt{\mathbb E\,\|f(x) - f_\theta(x)\|^2} \le c P^{-1/\gamma}$ in terms of the number of parameters $P$ of the DNN, a scaling constant $\gamma$, and a prefactor $c$. Roughly speaking, for fully-connected networks the computational cost of evaluating $f_\theta$ is proportional to $P$, since each parameter is used once; for CNNs it scales as $Pwh$ (for $w, h$ the width and height of the image), for RNNs as $P\ell$ (where $\ell$ is the sequence length), and for Transformers as $P\ell^2$. We can then rewrite the bound to match the scaling assumption: $C(f_\theta) \le s c^\gamma \epsilon^{-\gamma}$ (where $s$ is either $1$, $wh$, $\ell$, or $\ell^2$).
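In code, Assumption 1 turns a target accuracy into a level index and a compute bound. A minimal sketch (hypothetical helper names; $c$ and $\gamma$ as in the assumption):

```python
import math

def level_for_error(eps):
    """Smallest level k such that the Assumption 1 error bound 2^-k is <= eps."""
    return math.ceil(-math.log2(eps))

def level_compute_bound(k, c=1.0, gamma=2.5):
    """Assumption 1 compute bound: C(f^k) <= c^gamma * 2^(gamma * k)."""
    return c ** gamma * 2.0 ** (gamma * k)
```

Chaining the two gives $C(f^k) \le c^\gamma 2^{\gamma k} \approx (c/\epsilon)^\gamma$, i.e. an $\epsilon$-accurate evaluation of the drift costs on the order of $\epsilon^{-\gamma}$.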
The ubiquity of these scaling laws in practice implies that this is a very reasonable assumption (note that we assume a bound on the $L^\infty$ rather than the $L^2$ error, but this is mainly to simplify the derivations; a bound on the $L^2$ error could be shown to be enough with a few extra assumptions and a bit more work). The constant $c$ in the assumption is closely related to the so-called HTMC norm $\|f\|_{M^\gamma}$ as defined in [11]: if for all $k$ the estimator $f_t^k$ has minimal computational complexity amongst all $2^{-k}$-estimators (in the sense that it minimizes circuit size), then $c = \|f\|_{M^\gamma}$. However, we do not need to assume that we have found the most computationally efficient estimator for the results of this paper to apply, and thus in general we only have $c \ge \|f\|_{M^\gamma}$.

Euler-Maruyama method: The baseline we consider to approximate our SDE algorithmically is the Euler-Maruyama method [18] together with a certain approximation $f^k$ of $f$ (typically the largest DNN we can train), yielding

$$y_{t+\eta} = y_t + \eta f_t^k(y_t) + \sqrt{\eta}\,\sigma_t Z_t$$

with a step size $\eta$ (which we assume constant) and $Z_t \sim \mathcal N(0,1)$. In the absence of noise ($\sigma_t = 0$), we recover the Euler method.

Remark 1. Note that for DDPM, the Euler-Maruyama discretization is slightly different from the usual implementation of DDPM, and similarly for DDIM. We describe in Appendix A why the two are equivalent up to subdominant terms as the learning rate goes to zero.

3 Multilevel Euler-Maruyama

Our strategy is to use a variation of the Multilevel Monte Carlo (MLMC) [6, 7] method at each step of the discretization¹:

$$y_{t+\eta} = y_t + \eta \sum_{k=k_{\min}}^{k_{\max}} \frac{B_k}{p_k}\left(f_t^k(y_t) - f_t^{k-1}(y_t)\right) + \sqrt{\eta}\,\sigma_t Z_t$$

¹ The original motivation for MLMC was to compute expectations over the sampling of SDEs, in which case multiple discretizations of SDE paths are computed, with different step sizes to obtain different levels of accuracy.
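For concreteness, the baseline Euler-Maruyama recursion above can be sketched as follows (a minimal reference implementation, not the paper's code; `drift` and `sigma` stand in for $f_t^k$ and $\sigma_t$):

```python
import numpy as np

def euler_maruyama(drift, sigma, y0, T, n_steps, rng):
    """Simulate y_{t+eta} = y_t + eta * f_t(y_t) + sqrt(eta) * sigma_t * Z_t.

    drift(t, y) approximates the drift f_t and sigma(t) is the noise scale;
    sigma = 0 recovers the plain Euler method.
    """
    eta = T / n_steps
    y = np.array(y0, dtype=float)
    for i in range(n_steps):
        t = i * eta
        z = rng.standard_normal(y.shape)  # Z_t ~ N(0, 1)
        y = y + eta * drift(t, y) + np.sqrt(eta) * sigma(t) * z
    return y
```

For example, with zero noise and drift $f(y) = -y$, the solution $e^{-T} y_0$ of $y' = -y$ is recovered up to an $O(\eta)$ discretization error.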
In our case, we only want to evaluate one SDE path, and the multiple levels result from the different DNN sizes. This similarity in the setting might lead to confusion, but MLMC is not specific to SDEs; it can be applied whenever one has access to a range of estimators with different errors and computational costs.

Panels: (a) DDPM MSE, (b) DDPM samples, (c) DDIM MSE, (d) DDIM samples.

Figure 1: (Left) We compare the ML-EM and EM methods of generation for DDPM (top) and DDIM (bottom) by plotting the MSE between the generated sample and the 'true' sample (generated with a 1000-step DDPM/DDIM) with the same initial and Brownian noise; the x-axis is the time in seconds required to generate 200 images. Solid lines are the traditional EM method with different network sizes $f^1,\dots,f^5$ and with number of steps ranging from 58 to 933. The crosses and dots are the ML-EM method with three networks $\{f^1, f^3, f^5\}$ and with either fixed probabilities or learned coefficients $\alpha_k, \beta_k$ (see Section 4). We add a $\Delta \in \{-3.0, -2.5, \dots, 2.5, 3.0\}$ to the $\beta_k$s and perform 15 trials over the sampling of the Bernoulli RVs (remember that the starting noise and Brownian motion are fixed). The sampling of the Bernoullis that yields the smallest MSE can be memorized; it is therefore fair to compare the straight lines of classical EM to the best trials of ML-EM. (Right) The first 6 generated images for the 'true sample' and selected instances of EM (A, B) and ML-EM (C, D, E).
For DDPMs, ML-EM with learned coefficients clearly outperforms all other methods, requiring in some cases 4 times less compute time than EM to reach the same MSE, or reaching a 10 times smaller MSE at the same compute time. For DDIM the advantage of ML-EM is less clear, but still present. Visually, it appears that the main advantage of ML-EM is that it avoids the discoloration/contrast issues present for EM with few steps. Interestingly, DDIM appears to suffer from these discolorations even with 1000 steps.

Here $B_k \sim \mathrm{Bernoulli}(p_k)$. The idea is that we are going to choose a probability $p_k$ that decreases exponentially in $k$, so that at most steps we will not need to evaluate the best estimator $f^{k_{\max}}$. Note that our guarantees will therefore be in terms of the expected computational cost

$$\mathbb E\, C(y_T) = \sum_{t,k} p_k C(f_t^k) \le \frac{T}{\eta}\, c^\gamma \sum_k p_k 2^{\gamma k}.$$

One could then use the probabilistic method to imply the existence of a deterministic choice of the $B_k$ that reaches a certain error at a certain computational cost. In practice, we observe that $C(y_T)$ concentrates around its expectation, whereas the error exhibits a significant variance over the sampling of the $B_k$ (though it is very consistent across different initializations of the SDE and the sampling of the Brownian motion). We therefore perform a best-of-15 to identify the optimal choices of Bernoulli random variables $B_k$. We choose $k_{\min}(t) = -\lceil \log_2 \|f_t\|_\infty \rceil$ so that we may take $f_t^{k_{\min}(t)-1} = 0$ as an estimator, and thus we recover the Euler-Maruyama method in expectation:

$$\mathbb E[y_{t+\eta} \mid y_t] = y_t + \eta f_t^{k_{\max}}(y_t) + \sqrt{\eta}\,\sigma_t Z_t.$$

We will also make the following classical Lipschitzness assumption:

Assumption 2. For all $t$ and $k$, $\mathrm{Lip}(f_t), \mathrm{Lip}(f_t^k) \le L$.
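A single step of the multilevel update can be sketched as follows (a simplified illustration, not the paper's implementation; `drifts[k]` plays the role of $f_t^k$, with the level below the cheapest one taken to be $0$ as in the choice of $k_{\min}$):

```python
import numpy as np

def ml_em_step(y, t, eta, drifts, probs, sigma_t, rng):
    """One Multilevel Euler-Maruyama step.

    Each correction (f^k - f^{k-1}) / p_k is included with probability p_k
    (B_k ~ Bernoulli(p_k)), so the drift estimate telescopes to f^{k_max}
    in expectation and the step is unbiased for the plain EM step.
    """
    total_drift = np.zeros_like(y)
    for k, (f_k, p_k) in enumerate(zip(drifts, probs)):
        if rng.random() < p_k:  # B_k = 1: evaluate level k (and level k - 1)
            f_hi = f_k(t, y)
            f_lo = drifts[k - 1](t, y) if k > 0 else np.zeros_like(y)
            total_drift = total_drift + (f_hi - f_lo) / p_k
    z = rng.standard_normal(y.shape)
    return y + eta * total_drift + np.sqrt(eta) * sigma_t * z
```

With all $p_k = 1$ and $\sigma_t = 0$, the sum telescopes exactly and the step reduces to a deterministic Euler step with the finest drift.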
We now bound the distance between $y_t$ and the Euler-Maruyama discretization $x_t^{(\eta)}$ of the true flow, $x_{t+\eta}^{(\eta)} = x_t^{(\eta)} + \eta f_t(x_t^{(\eta)}) + \sqrt{\eta}\,\sigma_t Z_t$.

Theorem 1. Under Assumptions 1 and 2, for any step size $\eta > 0$, error $\epsilon > 0$, and time $T = i\eta > 0$, if we choose $k_{\min} = -\lfloor \log_2 c \rfloor$, $k_{\max} = -\log_2 \frac{\epsilon}{2Le^{L(T+\eta)}}$, and $p_k = \min\{C 2^{-(1+\frac{\gamma}{2})k}, 1\}$ for some constant $C$, we have $\mathbb E\,\|x_T^{(\eta)} - y_T\|^2 \le \epsilon^2$ at an expected computational cost of at most

$$18\left(L^3 T^3 + L T^2\right) E_\gamma\!\left(\frac{c\, e^{L(T+\eta)}}{L\epsilon}\right), \qquad E_\gamma(r) = \begin{cases} \frac{1}{\left(1 - 2^{\frac{\gamma}{2} - 1}\right)^2}\, r^2 & \gamma < 2, \\ r^2 \left(3 + \log_2 r\right) & \gamma = 2, \\ \frac{2}{3(\gamma - 2)} \left(2^{\frac{\gamma}{2} - 1} - 1\right)^{-2} r^\gamma & \gamma > 2. \end{cases}$$

Harder than Monte Carlo (HTMC) regime ($\gamma > 2$): This result is particularly relevant in the Harder than Monte Carlo (HTMC) regime [11], when $\gamma > 2$, where the computational complexity of solving the SDE ($\epsilon^{-\gamma}$) is the same as the complexity of a single evaluation of the best estimator! As already discussed in Section 1.1, there is strong empirical evidence that DNNs follow scaling laws which are "flat enough" to correspond to $\gamma$s that are far into the HTMC regime. In Figure 2, we estimate $\gamma \approx 2.5$ for the CelebA dataset (cropped and downscaled to $64\times 64$).

Independence of the step size $\eta$: Notice how the bound on the compute required to reach an $\epsilon$ is essentially independent of the step size $\eta$ (to be precise, as $\eta \searrow 0$, it decreases to a finite value). This is because the probability that we evaluate any one of the levels $f^k$ is proportional to $\eta$, so the number of evaluations of each level $f^k$ remains constant as $\eta \searrow 0$. In this limit, $y_t$ converges to some form of Poisson jump process that approximates the original SDE $x_t$ with the same error and compute guarantees.
This also implies that there is no need to use a more complex discretization scheme than the EM method, because one can always take a smaller $\eta$ at no computational cost (the only cost is that we need to sample the noise $Z_t$ and add it, but when working with large DNNs, this cost is negligible in comparison to the DNN evaluations).

Choosing the probabilities $p_k$: Theorem 1 requires a very specific choice of probabilities, which requires knowledge of the rate $\gamma$ and the Lipschitz constant $L$, neither of which is really accessible at first glance. Thankfully, it turns out that we have a lot of flexibility in our choice of the $p_k$s: the proof can easily be adapted to show that if $p_k = C 2^{-\beta k}$ for a constant $C$ and any exponent $\beta$ in the range $(2, \gamma)$, then the expected squared error will be $O(C^{-1} \epsilon^{2-\beta})$ with an $O(C \epsilon^{\beta-\gamma})$ expected computational cost, so that by choosing $C \sim \epsilon^{-\beta}$, one can reach an $\epsilon$ error with $O(\epsilon^{-\gamma})$ compute, recovering the right rate. This means that we only have to tune one hyper-parameter, $C$, to reach the optimal rate. Choosing $\beta = 2$ or $\beta = \gamma$ also leads to the right rates up to some additional $\log \epsilon$ terms, and these choices are particularly straightforward to implement: $\beta = \gamma$ corresponds to choosing $p_k$ inversely proportional to the compute time of $f^k$, which can easily be estimated. Nevertheless, we also propose in the next section a method for learning the $p_k$s with SGD to obtain as much computational gain as possible, by achieving not only the optimal rate, but also the optimal prefactor.

Choosing $k_{\min}$ and $k_{\max}$: The choice of $k_{\min}$ has very little impact on the final error. The choice of $k_{\max}$ induces a lower bound on the minimal error that we can reach, since the ML-EM method will always be less accurate than using only the best estimator $f^{k_{\max}}$ (though it can reach a similar error much faster).
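The $\beta = \gamma$ choice, i.e. probabilities inversely proportional to the per-level cost, can be sketched as follows (hypothetical helper names; `eval_times` are the measured per-level costs $T_k$):

```python
def level_probs(eval_times, C):
    """p_k = min(C / T_k, 1): inversely proportional to the cost of level k
    (the beta = gamma choice), clipped to valid probabilities."""
    return [min(C / t_k, 1.0) for t_k in eval_times]

def expected_step_cost(probs, eval_times):
    """Expected compute of one ML-EM step: sum_k p_k * T_k."""
    return sum(p * t_k for p, t_k in zip(probs, eval_times))
```

Each unclipped level then contributes exactly $C$ to the expected per-step cost, so raising $C$ trades compute for accuracy through a single knob.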
In practice, the choice of $k_{\max}$ will mostly be determined by computational constraints: what is the largest network that can reasonably be trained on a certain compute budget?

Exponential constant: The exponential term $e^{TL}$ emerges naturally from the use of a Grönwall proof technique and also appears in the classical EM method. It represents the fact that, in the worst case, an error of size $\epsilon$ in the first few steps of the SDE could get scaled up by $e^{TL}$ by the time we reach the final time $T$ (e.g. if $f(x) = Lx$). In diffusion models, since denoising acts as a form of contraction rather than an expansion of the error, it is reasonable to hope that this exponential blow-up will be naturally avoided, and this seems to be what we observe in our experiments. Note that if one were instead in a setting where this exponential blow-up is real, it might be advantageous to choose time-dependent probabilities $p_k(t)$ that decrease in time, to make fewer errors at times $t$ whose errors will be most impactful. We discuss a method for doing so in the next section.

3.1 Adaptive Method

The question of how the errors from different times $t$ propagate to the final time is crucial in practice: ideally, we would like to adapt our estimation method to be more accurate at times $t$ where errors contribute more to the final error. This can be achieved by letting the probabilities $p_k$ and the max accuracy $k_{\max}$ depend on the time $t$. One could try to bound this error propagation with some quantity and use it to choose $p_k(t)$ and $k_{\max}(t)$. Instead, we take a very "deep learning" approach and learn the optimal probabilities by minimizing the error with SGD directly. We consider a simple time dependence $p_k(t) = \sigma(\alpha_k \log(t + \delta) + \beta_k)$ for parameters $\alpha_k, \beta_k$, a small $\delta$ ($\delta = 0.1$ in our experiments), and the sigmoid $\sigma$.
Our goal is to find the parameters $\alpha_k, \beta_k$ that minimize the regularized loss

$$\mathcal L_\lambda(\alpha_k, \beta_k) = \mathbb E_{x_T, Z_t, B_k}\left[\left\|x_T^{(\eta)} - y_T\right\|^2\right] + \lambda \sum_{i=0}^{T/\eta - 1} \sum_k p_k(i\eta)\, T_k,$$

where $T_k$ is the computational cost (either in FLOPs or in time) of one evaluation of $f^k$, which can easily be estimated empirically. The expectation is over the sampling of the starting point $x_T \sim \mathcal N(0,1)$ of the backward process, the noise $Z_t \sim \mathcal N(0,1)$, and the Bernoullis $B_k(t) \sim \mathrm{Bernoulli}(p_k(t))$ at each step.

There are two issues that make it hard to compute the gradient $\nabla \mathcal L_\lambda(\alpha_k, \beta_k)$: we need to differentiate 'through' the sampling of the Bernoulli random variables, and, on a more practical level, we cannot realistically perform backpropagation through the whole SDE, as it would require keeping in memory all activations of every application of the network (for all times $i\eta$ and all samples of $x_T, Z_t$), which would overshoot our memory budget. But both can be fixed with the right techniques:

Differentiating through Bernoullis: For any function $f(B)$ of a Bernoulli random variable $B \sim \mathrm{Bernoulli}(p)$, the derivative of the expectation $\mathbb E[f(B)]$ w.r.t. its probability $p$ is $f(1) - f(0)$. Since

$$\mathbb E\left[f(B)\frac{B - p}{p(1-p)}\right] = p f(1) \frac{1-p}{p(1-p)} + (1-p) f(0) \frac{0-p}{p(1-p)} = f(1) - f(0),$$

we can use $f(B)\frac{B-p}{p(1-p)}$ as an unbiased estimator for $\frac{d}{dp}\mathbb E[f(B)]$.
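This estimator is easy to check numerically; a quick Monte Carlo sketch (hypothetical function name):

```python
import numpy as np

def bernoulli_grad_estimate(f, p, n_samples, rng):
    """Average the unbiased estimator f(B) (B - p) / (p (1 - p)) over samples
    of B ~ Bernoulli(p); converges to d/dp E[f(B)] = f(1) - f(0)."""
    b = (rng.random(n_samples) < p).astype(float)  # Bernoulli(p) samples
    estimates = f(b) * (b - p) / (p * (1.0 - p))
    return estimates.mean()
```

For $f(B) = 3B + 1$ the true derivative is $f(1) - f(0) = 3$, and the average approaches it at the usual $1/\sqrt{n}$ Monte Carlo rate.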
Now note that, because we divide by $p(1-p)$, which approaches zero as $p \to 0$ or $1$, this estimator could potentially have a lot of variance. Thankfully, if $p$ is parametrized as a sigmoid, as in our setting $p(t) = \sigma(\alpha \log(t+\delta) + \beta)$, then by the chain rule we have

$$\partial_\alpha \mathbb E[f(B)] = \mathbb E\left[f(B)\frac{B - p(t)}{p(t)(1-p(t))}\right] p(t)(1-p(t)) \log(t+\delta) = \mathbb E[f(B)(B - p(t))] \log(t+\delta),$$
$$\partial_\beta \mathbb E[f(B)] = \mathbb E\left[f(B)\frac{B - p(t)}{p(t)(1-p(t))}\right] p(t)(1-p(t)) = \mathbb E[f(B)(B - p(t))],$$

so that the estimates $f(B)(B - p(t))\log(t+\delta)$ and $f(B)(B - p(t))$ for the derivatives w.r.t. $\alpha$ and $\beta$ can be expected to have bounded variance as long as $f(B)$ and $\log(t+\delta)$ remain bounded.

Forward gradient computation instead of backpropagation: To avoid the memory cost of backpropagation, we instead rely on forward propagation [3], which allows us to compute the scalar product $\nabla \mathcal L_\lambda^T v$ of the gradient $\nabla \mathcal L_\lambda$ with a vector $v$ at a memory usage that is constant in the time $i\eta$. The gradient $\nabla \mathcal L_\lambda$ can then be approximated by $\left(\nabla \mathcal L_\lambda^T v\right) v$ for a random Gaussian vector $v \sim \mathcal N(0, I)$. This is again an unbiased estimator, since $\mathbb E_v\left[v v^T\right] \nabla \mathcal L_\lambda = \nabla \mathcal L_\lambda$.

Putting everything together, we estimate the gradient $\nabla_\alpha \mathcal L_\lambda$ by

$$\left\|x_T^{(\eta)} - y_T\right\|^2 \sum_{i=1}^{T/\eta - 1} \left(B_k(i\eta) - p_k(i\eta)\right) \log(i\eta + \delta) + \left(\nabla_{AD}\left\|x_T^{(\eta)} - y_T\right\|^2{}^{T} v\right) v_\alpha + \lambda \sum_{i=0}^{T/\eta - 1} T_k\, p_k(i\eta)\left(1 - p_k(i\eta)\right) \log(i\eta + \delta),$$

where $v$ is a random Gaussian vector of dimension $2(k_{\max} - k_{\min})$ (that is, the same dimension as the $\alpha_k, \beta_k$) and $v_\alpha$ is the first half of $v$, which corresponds to the $\alpha$s. For the second term, $\nabla_{AD}\|x_T^{(\eta)} - y_T\|^2$ is the "automatic differentiation" gradient, which treats the $B_k$ as if they were independent of $p_k$. Finally, note how we use traditional differentiation for the regularization term, since it does not suffer from the two aforementioned challenges.
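The forward-gradient trick can likewise be sketched with plain linear algebra (a toy version where the true gradient is given directly rather than computed by forward-mode autodiff):

```python
import numpy as np

def forward_gradient_estimate(grad, n_samples, rng):
    """Average (grad . v) v over Gaussian tangents v ~ N(0, I).

    Only the directional derivative grad . v is needed per sample (which
    forward-mode autodiff provides at constant memory); the average is an
    unbiased estimate of grad since E[v v^T] = I.
    """
    g = np.asarray(grad, dtype=float)
    V = rng.standard_normal((n_samples, g.size))  # tangent directions v
    jvps = V @ g                                  # directional derivatives grad . v
    return (jvps[:, None] * V).mean(axis=0)
```

A single sample is already unbiased but very noisy; in the paper's setting one such sample per SGD step is combined with the averaging effect of SGD itself.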
Our estimate for the gradient $\nabla_\beta \mathcal L_\lambda$ is obtained by removing the $\log(i\eta + \delta)$ terms and replacing $v_\alpha$ by $v_\beta$, the second half of the vector $v$.

4 Numerical Experiments

Training: The experiments are performed on the CelebA dataset [15], cropped and rescaled to a size of $64\times 64$. This task was chosen as it matched the relatively limited compute at our disposal (two GeForce RTX 2080 Ti). We train a sequence of UNets $f^1, f^2, f^3, f^4, f^5$ of increasing sizes, resulting in better and better approximations of the true score. Our UNets have the following properties:

• At each level of the UNet, we divide the image dimension by two and double the number of channels (starting from a certain "base dimension"). We have 4 levels, so that the "bottom" of the UNet has an $8\times 8$ shape.
• The filters are factored as the composition of a per-channel $3\times 3$ convolution followed by a $1\times 1$ convolution across channels.
• There are $L_1$ residual layers at the bottom of the UNet and $L_2$ residual layers at the shallower scales, in both the downscaling and the upscaling parts.
• The different networks have base dimensions 8, 16, 32, 64, bottom depths $L_1 = 5, 10, 20, 40$, and intermediate depths $L_2 = 2, 3, 5, 7$ respectively.
• Each of these networks was first trained separately on the usual denoising loss, with Adam.

Note that it is now common practice to train multiple smaller models to do hyper-parameter search before training the largest models, so practitioners might already have access to a set of trained models with a range of sizes and accuracies. And even if this is not the case, the computational cost of training the smaller models is almost insignificant in comparison to the training cost of the larger models.

Generation: For image generation, we followed the standard DDPM procedure with a baseline of 1000 steps and a cosine noise schedule [21].
We also applied clipping to the predicted denoised image [9]. Since we do not have access to the true score, we use the largest model $f^5$ with 1000 generation steps as our 'true generated sample' and evaluate other generation methods in terms of how different their generated images are from this true sample (with the same starting noise $x_T$ and SDE noise $Z_t$).

For the baseline EM method, we try a range of numbers of steps (250, 500, 750, 900, 1000) over the 5 networks $f^1, \dots, f^5$. We can see how changing the number of steps allows us to trade off computation time for a smaller error, and that the error seems to saturate a bit before 1000 steps. Obviously, when we approximate the 1000-step $f^5$ generation with the same network and fewer steps, the error drops very suddenly to zero as the number of steps approaches 1000, but this very small error is misleading, because our actual goal is to approach the images generated with the true score, and we used our largest UNet $f^5$ only as a proxy for it. We therefore focus only on errors above $10^{-3}$, as anything below this threshold is overfitting to $f^5$ rather than approaching the true score.

For ML-EM we only used three models, $\{f^1, f^3, f^5\}$. The probabilities $p_k$ were chosen with three strategies:

• "Fixed probs.", orange crosses: Taking $p_k = C T_k^{-1}$ is the simplest method, since the average computation time (or FLOPs) $T_k$ of $f^k$ can easily be computed. As discussed in Section 3, this is sufficient to obtain optimal rates. We then vary $C$ to obtain a range of errors/times. With this method, the probabilities are constant in time.
• "Fixed probs.", green crosses: From our theory, the optimal choice of $p_k$ should be $p_k = C' 2^{-(1+\frac{\gamma}{2})k} = C T_k^{-(\frac{1}{\gamma} + \frac12)}$. We estimate $\gamma = 2.5$ (see Figure 2) and therefore choose $p_k = C T_k^{-0.9}$ over a range of $C$s. We do not observe any significant difference between the two "Fixed probs." methods.
• "Learned coeffs.", blue dots: We optimize the $\alpha_k, \beta_k$ parameters with 50 steps of SGD (as described in Section 3.1) with a batch size of 300 and $\lambda = 0.1$ for DDPM, and $\lambda = 1.0$ for DDIM. We then obtain a range of errors/times by adding a delta to the constant coefficients, $\beta_k \leftarrow \beta_k + \Delta$, for $\Delta$ ranging from $-3.0$ to $3.0$. This method clearly outperforms the "Fixed probs." methods.

GPU batching: On GPUs, the compute time is typically only linear in the number of function evaluations if these function evaluations are batched. To take advantage of this, we generate $N = 200$ images simultaneously and share the Bernoulli variables across the batch, so that we either have to evaluate $f^k$ over the whole batch or not at all, leading to a significant speedup. However, we do not use this trick when learning the $\alpha_k, \beta_k$ with SGD, because we need our approximate gradient to concentrate, and sharing Bernoullis breaks the independence, leading to a higher variance.

5 Conclusion

We introduced the ML-EM method for discretizing SDEs and ODEs, which is especially useful when the drift term lies in the Harder than Monte Carlo (HTMC) regime. This appears to apply to typical applications of diffusion models, leading to fourfold speedups in the time and compute required to generate high-quality images. The advantage of ML-EM over EM should only increase for larger and more complex datasets, so one could expect tenfold speedups or more at the kinds of scales that are common in industry. This method can also be used in combination with other methods for speeding up diffusion models, such as DDIM.

References

[1] Marloes Arts, Victor Garcia Satorras, Chin-Wei Huang, Daniel Zugner, Marco Federici, Cecilia Clementi, Frank Noé, Robert Pinsler, and Rianne van den Berg. Two for one: Diffusion models and force fields for coarse-grained molecular dynamics. Journal of Chemical Theory and Computation, 19(18):6151–6159, 2023.
[2] Yasaman Bahri, Ethan Dyer, Jared Kaplan, Jaehoon Lee, and Utkarsh Sharma. Explaining neural scaling laws. Proceedings of the National Academy of Sciences, 121(10), 2024.

[3] Atılım Güneş Baydin, Barak A. Pearlmutter, Don Syme, Frank Wood, and Philip Torr. Gradients without backpropagation. arXiv preprint arXiv:2202.08587, 2022.

[4] Blake Bordelon, Alexander Atanasov, and Cengiz Pehlevan. A dynamical model of neural scaling laws. arXiv, 2024.

[5] Blake Bordelon, Mary I. Letey, and Cengiz Pehlevan. Theory of scaling laws for in-context regression: Depth, width, context and time. arXiv, 2025.

[6] Michael B. Giles. Multilevel Monte Carlo path simulation. Operations Research, 56(3):607–617, 2008.

[7] Michael B. Giles. Multilevel Monte Carlo methods. Acta Numerica, 24:259–328, 2015.

[8] Tom Henighan, Jared Kaplan, Mor Katz, Mark Chen, Christopher Hesse, Jacob Jackson, Heewoo Jun, Tom B. Brown, Prafulla Dhariwal, Scott Gray, et al. Scaling laws for autoregressive generative modeling. arXiv preprint arXiv:2010.14701, 2020.

[9] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.

[10] Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. arXiv preprint arXiv:2203.15556, 2022.

[11] Arthur Jacot. Deep learning as a convex paradigm of computation: Minimizing circuit size with resnets. arXiv preprint arXiv:2511.20888, 2025.

[12] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.

[13] Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matthew Le. Flow matching for generative modeling. In International Conference on Learning Representations, 2023.

[14] Xingchao Liu, Chengyue Gong, and Qiang Liu. Flow straight and fast: Learning to generate and transfer data with rectified flow. In International Conference on Learning Representations, 2023.

[15] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), December 2015.

[16] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. DPM-Solver: A fast ODE solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022.

[17] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. DPM-Solver++: Fast solver for guided sampling of diffusion probabilistic models. arXiv preprint arXiv:2211.01095, 2022.

[18] Gisiro Maruyama. Continuous Markov processes and stochastic equations. Rendiconti del Circolo Matematico di Palermo, 4(1):48–90, 1955.

[19] Chenlin Meng, Robin Rombach, Ruiqi Gao, Diederik Kingma, Stefano Ermon, Jonathan Ho, and Tim Salimans. On distillation of guided diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 14297–14306, 2023.

[20] Eric J. Michaud, Ziming Liu, Uzay Girit, and Max Tegmark. The quantization model of neural scaling. arXiv, 2023.

[21] Alexander Quinn Nichol and Prafulla Dhariwal. Improved denoising diffusion probabilistic models. In International Conference on Machine Learning, pages 8162–8171. PMLR, 2021.

[22] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models.
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10684–10695, 2022.

[23] Tim Salimans and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. In International Conference on Learning Representations, 2022.

[24] Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach. Adversarial diffusion distillation. arXiv preprint arXiv:2311.17042, 2023.

[25] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 2256–2265. PMLR, 07–09 Jul 2015.

[26] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. In International Conference on Learning Representations, 2021.

[27] Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. In International Conference on Machine Learning, pages 32211–32252. PMLR, 2023.

[28] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021.

[29] Joseph L. Watson, David Juergens, Nathaniel R. Bennett, Brian L. Trippe, Jason Yim, Helen E. Eisenach, Woody Ahern, Andrew J. Borst, Robert J. Ragotte, Lukas F. Milles, et al. De novo design of protein structure and function with RFdiffusion. Nature, 620(7976):1089–1100, 2023.

[Figure 2 plot: denoising error minus 0.15 against evaluation time $t$ on log-log axes, showing the UNets $f^k$ and the fitted line $0.06\, t^{-0.4}$.]

Figure 2: Estimating $\gamma \approx 2.5$: We plot the denoising error $\epsilon$ minus $0.15$ against the evaluation time for a range of UNets $f^1, \dots, f^5$. The constant $0.15$ was chosen by hand to approximate the minimal denoising error (it was chosen so that the set of points would align as closely as possible to a line in the log-log plot). We see that on a log-log scale the plot fits well with an $\epsilon \sim t^{-0.4}$ slope, which would correspond to $\gamma = \frac{1}{0.4} = 2.5$, which lies in the HTMC regime ($\gamma > 2$).

A DDPM/DDIM as approximate Euler-Maruyama methods

In practice, DDPMs (and DDIMs) are defined in terms of a sequence of steps $\beta_1, \dots, \beta_M$ (which are essentially equivalent to time-dependent step-sizes $\eta_t$), which then define a forward process
$$y_m = \sqrt{1-\beta_m}\, y_{m-1} + \sqrt{\beta_m}\, Z_m$$
for some noise $Z_m \sim \mathcal{N}(0, I)$. Defining $\alpha_m = 1 - \beta_m$ and $\bar\alpha_m = \alpha_1 \cdots \alpha_m$, one can easily prove that $y_m$ is Gaussian with
$$y_m \sim \mathcal{N}\!\left(\sqrt{\bar\alpha_m}\, y_0,\ (1-\bar\alpha_m) I\right).$$
We already see an approximation between $y_m$ and the continuous process $x_t$ at time $t = \beta_1 + \cdots + \beta_m$, since $x_t \sim \mathcal{N}\!\left(e^{-\frac{t}{2}} y_0,\ (1-e^{-t}) I\right)$ and $\bar\alpha_m \approx e^{-\beta_1 - \cdots - \beta_m} = e^{-t}$.

Given $\sigma_m = \sqrt{1-\bar\alpha_m}$ and $\epsilon_m(y_m) = -\sigma_m \nabla \log \rho_{y_m}(y_m)$ for $\rho_{y_m}$ the density of $y_m$, the DDPM backward process is defined as
$$y_{m-1} = \frac{1}{\sqrt{\alpha_m}}\, y_m - \frac{\beta_m}{\sqrt{\alpha_m}\,\sigma_m}\, \epsilon_m(y_m) + \sqrt{\beta_m}\,\frac{\sigma_{m-1}}{\sigma_m}\, Z_m,$$
and the DDIM backward process is defined as
$$\frac{y_{m-1}}{\sqrt{\bar\alpha_{m-1}}} = \frac{y_m}{\sqrt{\bar\alpha_m}} + \left(\sqrt{\frac{1-\bar\alpha_{m-1}}{\bar\alpha_{m-1}}} - \sqrt{\frac{1-\bar\alpha_m}{\bar\alpha_m}}\right) \epsilon_m(y_m).$$

For the DDPM equivalence, observe that $\frac{1}{\sqrt{\alpha_m}} = \frac{1}{\sqrt{1-\beta_m}} = 1 + \frac{\beta_m}{2} + O(\beta_m^2)$; we then obtain the similar approximations $\frac{\beta_m}{\sqrt{\alpha_m}\,\sigma_m} \approx \frac{\beta_m}{\sigma_m}$ and $\frac{\sigma_{m-1}}{\sigma_m} \approx 1$ (up to $O(\beta_m^2)$ terms), so that
$$y_{m-1} - y_m \approx \beta_m \left(\frac{1}{2}\, y_m + \nabla \log \rho_{y_m}(y_m)\right) + \sqrt{\beta_m}\, Z_m.$$
This implies that the backward DDPM process is approximately equal to an Euler-Maruyama approximation of the backward SDE
$$-dy_t = \left(\frac{1}{2}\, y_t + \nabla \log \rho_{y_t}(y_t)\right) dt + dW_t$$
with step-size $\beta_m$.

Let us first rewrite the DDIM formula, using the fact that $\bar\alpha_m = \alpha_m \bar\alpha_{m-1}$:
$$y_{m-1} = \frac{y_m}{\sqrt{\alpha_m}} + \left(\sqrt{1-\bar\alpha_{m-1}} - \sqrt{\alpha_m^{-1} - \bar\alpha_{m-1}}\right) \epsilon_m(y_m).$$
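The first-order expansions invoked in this appendix, $\frac{1}{\sqrt{1-\beta}} = 1 + \frac{\beta}{2} + O(\beta^2)$ and the Taylor expansion of $\sqrt{x - \bar\alpha_{m-1}}$ around $x = 1$, can be spot-checked numerically. A minimal standard-library sketch (the concrete values are hypothetical, not from the paper):

```python
import math

# Check 1/sqrt(1 - beta) = 1 + beta/2 + O(beta^2): the remainder is 3*beta^2/8 + ...
for beta in (0.1, 0.05, 0.01):
    exact = 1.0 / math.sqrt(1.0 - beta)
    approx = 1.0 + beta / 2.0
    assert abs(exact - approx) < beta**2

# Check sqrt(x - a) ≈ sqrt(1 - a) + (x - 1) / (2 sqrt(1 - a)) around x = 1.
a = 0.3  # plays the role of alpha-bar_{m-1}; the value is arbitrary in (0, 1)
for dx in (0.05, 0.01):
    exact = math.sqrt(1.0 + dx - a)
    approx = math.sqrt(1.0 - a) + dx / (2.0 * math.sqrt(1.0 - a))
    assert abs(exact - approx) < dx**2
```

Both errors shrink quadratically in the step, which is exactly the $O(\beta_m^2)$ accuracy the derivations rely on.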
Approximating $\frac{1}{\sqrt{\alpha_m}} \approx 1 + \frac{\beta_m}{2}$ and taking a Taylor approximation of the function $\sqrt{x - \bar\alpha_{m-1}}$ around $x = 1$, we obtain
$$y_{m-1} - y_m \approx \frac{\beta_m}{2}\, y_m + \frac{1 - \alpha_m^{-1}}{2\sqrt{1-\bar\alpha_{m-1}}}\, \epsilon_m(y_m) \approx \frac{\beta_m}{2}\, y_m + \frac{\beta_m}{2}\,\frac{\sigma_m}{\sigma_{m-1}}\, \nabla \log \rho_{y_m}(y_m) \approx \beta_m \left(\frac{1}{2}\, y_m + \frac{1}{2}\, \nabla \log \rho_{y_m}(y_m)\right),$$
which is the Euler approximation of the backward ODE
$$-\frac{dy_t}{dt} = \frac{1}{2}\, y_t + \frac{1}{2}\, \nabla \log \rho_{y_t}(y_t).$$

B Proofs

Theorem 2. Under Assumptions 1 and 2, for any step size $\eta > 0$, error $\epsilon > 0$, and time $T = i\eta > 0$, if we choose $k_{\min} = -\left\lfloor \frac{\log_2 c}{\gamma} \right\rfloor$, $k_{\max} = \log_2 \frac{2 e^{L(T+\eta)}}{L\epsilon}$, and $p_k = \min\{C 2^{-(1+\frac{\gamma}{2})k}, 1\}$ for some constant $C$, we have
$$\mathbb{E}\left\|x^{(\eta)}_T - y_T\right\|^2 \le \epsilon^2$$
at an expected computational cost of at most
$$18\left(L^3 T^3 + \frac{LT}{2}\right) E_\gamma\!\left(\frac{c^{\frac{1}{\gamma}}\, e^{L(T+\eta)}}{L\epsilon}\right), \qquad \text{where } E_\gamma(r) = \begin{cases} \frac{1}{(1-2^{\frac{\gamma}{2}-1})^2}\, r^2 & \gamma < 2,\\ r^2\, (3 + \log_2 r) & \gamma = 2,\\ \frac{2^{3(\gamma-2)}}{(2^{\frac{\gamma}{2}-1}-1)^2}\, r^\gamma & \gamma > 2. \end{cases}$$

Proof. As a reminder, here is the formula for the ML-EM method:
$$y_{t+\eta} = y_t + \eta \sum_{k=k_{\min}}^{k_{\max}} \frac{B_k(t)}{p_k} \left(f^k_t(y_t) - f^{k-1}_t(y_t)\right) + \sqrt{\eta}\, \sigma_t Z_t.$$
Note that since we chose $k_{\min} = -\left\lfloor \frac{\log_2 c}{\gamma} \right\rfloor < -\frac{\log_2 c}{\gamma} + 1$, the estimator $f^{k_{\min}-1}_t$ must have compute bounded by $c\, 2^{\gamma(k_{\min}-1)} < 1$, and therefore we take it to be the constant $0$ function.

We will track the evolution of the error in time (with the usual Grönwall's Lemma strategy), splitting the error into a bias term $b_t = \left\|\mathbb{E}[y_t] - x^{(\eta)}_t\right\|$ and a variance term $v^2_t = \mathbb{E}\left\|y_t - \mathbb{E}[y_t]\right\|^2$, where the expectation averages over the sampling of the $B_k(t)$, not the $Z_t$, which we assume to be fixed (in other terms, our analysis conditions on the $Z_t$, which are shared between $y_t$ and $x^{(\eta)}_t$).
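The proof averages only over the Bernoulli variables $B_k(t)$; the key property is that the multilevel sum $\sum_k \frac{B_k(t)}{p_k}\left(f^k_t - f^{k-1}_t\right)$ is an unbiased estimator of $f^{k_{\max}}_t$ (with the level below $k_{\min}$ taken to be zero). This can be checked by exact enumeration of all Bernoulli outcomes; a minimal standard-library sketch with hypothetical toy values:

```python
import itertools

# Hypothetical drift values f^0..f^3 at one fixed point, and probabilities p_k.
f = [1.0, 1.3, 1.42, 1.45]
p = [1.0, 0.5, 0.25, 0.125]
# Level differences f^k - f^{k-1}, with f^{-1} taken to be 0.
deltas = [f[k] - (f[k - 1] if k > 0 else 0.0) for k in range(len(f))]

# Exact expectation of sum_k B_k/p_k * deltas[k], enumerating all 2^4 outcomes.
expectation = 0.0
for outcome in itertools.product([0, 1], repeat=len(p)):
    prob = 1.0
    value = 0.0
    for k, b in enumerate(outcome):
        prob *= p[k] if b else (1.0 - p[k])
        if b:
            value += deltas[k] / p[k]
    expectation += prob * value

# The estimator is unbiased: its expectation telescopes to f^{kmax}.
assert abs(expectation - f[-1]) < 1e-12
```

Since $\mathbb{E}[B_k/p_k] = 1$, the expectation telescopes to $f^{k_{\max}}$ regardless of the $p_k$; the $p_k$ only affect the variance and the expected cost.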
First, we note that $b_{t+\eta}$ can be bounded in terms of $b_t$ and $v_t$:
$$b_{t+\eta} = \left\| \mathbb{E}\!\left[ y_t + \eta \sum_{k=k_{\min}}^{k_{\max}} \frac{B_k(t)}{p_k}\left(f^k_t(y_t) - f^{k-1}_t(y_t)\right) + \sqrt{\eta}\,\sigma_t Z_t \right] - \left( x^{(\eta)}_t + \eta f_t(x^{(\eta)}_t) + \sqrt{\eta}\,\sigma_t Z_t \right) \right\|$$
$$= \left\| \mathbb{E}\!\left[y_t\right] - x^{(\eta)}_t + \eta \left( \mathbb{E}\!\left[ f^{k_{\max}}_t(y_t) \right] - f_t(x^{(\eta)}_t) \right) \right\|$$
$$\le \left\| \mathbb{E}\!\left[y_t\right] - x^{(\eta)}_t \right\| + \eta \left\| \mathbb{E}\!\left[ f^{k_{\max}}_t(y_t) - f_t(y_t) \right] \right\| + \eta \left\| \mathbb{E}\!\left[ f_t(y_t) \right] - f_t(x^{(\eta)}_t) \right\|$$
$$\le b_t + \eta\, 2^{-k_{\max}} + \eta L \sqrt{b^2_t + v^2_t} \le (1+\eta L)\, b_t + \eta L\, v_t + \eta\, 2^{-k_{\max}},$$
where we used $\sqrt{b^2_t + v^2_t} \le b_t + v_t$ in the last inequality.

On the other hand, $v_{t+\eta}$ can be bounded in terms of $v_t$, by relying on the conditional variance formula, conditioning on $y_t$:
$$v^2_{t+\eta} = \mathbb{E}\left\| y_{t+\eta} - \mathbb{E}[y_{t+\eta} \mid y_t] \right\|^2 + \mathbb{E}\left\| \mathbb{E}[y_{t+\eta} \mid y_t] - \mathbb{E}[y_{t+\eta}] \right\|^2$$
$$= \eta^2 \sum_{k=k_{\min}}^{k_{\max}} \frac{1}{p_k}\, \mathbb{E}\left\| f^k_t(y_t) - f^{k-1}_t(y_t) \right\|^2 + \mathbb{E}\left\| y_t + \eta f^{k_{\max}}_t(y_t) - \mathbb{E}\!\left[ y_t + \eta f^{k_{\max}}_t(y_t) \right] \right\|^2$$
$$\le 9\eta^2 \sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k} + (1+\eta L)^2\, v^2_t,$$
where we used the fact that $\left\|f^k_t(x) - f^{k-1}_t(x)\right\| \le \left\|f^k_t(x) - f_t(x)\right\| + \left\|f_t(x) - f^{k-1}_t(x)\right\| \le 2^{-k} + 2^{-k+1} = 3 \cdot 2^{-k}$, and $\mathrm{Var}(g(X)) \le \mathrm{Lip}(g)^2\, \mathrm{Var}(X)$ for the last inequality.

We can unroll the recursive bound for $v_t$ into a direct bound:
$$v^2_{i\eta} \le 9\eta^2 \sum_{j=0}^{i} (1+\eta L)^{2(i-j)} \sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k} \le \frac{9\eta^2\, (1+\eta L)^{2i}}{1 - (1+\eta L)^{-2}} \sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k} \le \frac{9\eta}{2L}\, (1+\eta L)^{2(i+1)} \sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k} \le \frac{9\eta}{2L}\, e^{2L(i+1)\eta} \sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k},$$
where we used $\frac{1}{1-(1+\eta L)^{-2}} = \frac{(1+\eta L)^2}{(1+\eta L)^2 - 1} \le \frac{(1+\eta L)^2}{2\eta L}$.

This in turn allows us to bound the bias term $b_t$ directly:
$$b_{i\eta} \le \eta L \sum_{j=0}^{i} (1+\eta L)^{i-j}\, v_{j\eta} + \eta\, 2^{-k_{\max}} \sum_{j=0}^{i} (1+\eta L)^{i-j}$$
$$\le \frac{3\sqrt{L}}{\sqrt{2}}\, \eta^{\frac{3}{2}}\, i\, (1+\eta L)^{i+1} \sqrt{\sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k}} + \frac{\eta\, (1+\eta L)^i}{1 - (1+\eta L)^{-1}}\, 2^{-k_{\max}}$$
$$\le \frac{3\sqrt{L}}{\sqrt{2}}\, \sqrt{\eta}\, (i\eta)\, e^{L(i+1)\eta} \sqrt{\sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k}} + \frac{1}{L}\, e^{L(i+1)\eta}\, 2^{-k_{\max}}.$$
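The two elementary inequalities used in these bounds, $\sqrt{b^2 + v^2} \le b + v$ and $\frac{1}{1-(1+\eta L)^{-2}} \le \frac{(1+\eta L)^2}{2\eta L}$, can be spot-checked numerically; a quick standard-library sketch (grid values are arbitrary):

```python
import math

# (i) sqrt(b^2 + v^2) <= b + v for b, v >= 0.
for b, v in ((0.0, 1.0), (0.5, 0.5), (2.0, 3.0)):
    assert math.sqrt(b * b + v * v) <= b + v

# (ii) 1/(1 - (1+x)^{-2}) <= (1+x)^2 / (2x) for x = eta*L > 0,
# which follows from (1+x)^2 - 1 = 2x + x^2 >= 2x.
for x in (1e-4, 0.01, 0.1, 1.0, 10.0):
    lhs = 1.0 / (1.0 - (1.0 + x) ** -2)
    rhs = (1.0 + x) ** 2 / (2.0 * x)
    assert lhs <= rhs
```

Inequality (ii) is what converts the geometric sum over steps into the $\frac{1}{2\eta L}$ factor in the unrolled variance bound.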
To reach an $\epsilon$ error, we choose $k_{\max} = -\log_2\!\left(\frac{L}{2}\, e^{-L(i+1)\eta}\, \epsilon\right)$ so that $\frac{1}{L}\, e^{L(i+1)\eta}\, 2^{-k_{\max}} \le \frac{\epsilon}{2}$. The $p_k$ are chosen as $p_k = \min\{C 2^{-(1+\frac{\gamma}{2})k}, 1\}$ for some constant $C$, so as to minimize both the sum $\sum_{k=k_{\min}}^{k_{\max}} \frac{2^{-2k}}{p_k}$ and the computational cost $c \sum_{k=k_{\min}}^{k_{\max}} p_k\, 2^{\gamma k}$. We will then choose a sufficiently large constant $C$ to guarantee an $\epsilon$ error.

Using the identity $(a+b)^2 \le 2a^2 + 2b^2$ and $(1+\eta L)^i \le e^{Li\eta}$, we simplify the total expected squared error:
$$b^2_{i\eta} + v^2_{i\eta} \le 2\left(\frac{3\sqrt{L}}{\sqrt{2}}\, \sqrt{\eta}\, (i\eta)\, e^{L(i+1)\eta}\right)^{\!2} C^{-1} \sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k} + \frac{\epsilon^2}{2} + \frac{9\eta}{2L}\, e^{2L(i+1)\eta}\, C^{-1} \sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k}$$
$$\le 9\eta \left( L (i\eta)^2 + \frac{1}{2L} \right) e^{2L(i+1)\eta}\, C^{-1} \sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k} + \frac{\epsilon^2}{2}.$$

We therefore choose
$$C = 18\eta \left( L(i\eta)^2 + \frac{1}{2L} \right) e^{2L(i+1)\eta} \left( \sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k} \right) \epsilon^{-2}$$
to obtain an expected squared error of $\epsilon^2$ at a computational cost of at most
$$i \sum_{k=k_{\min}}^{k_{\max}} p_k\, c\, 2^{\gamma k} \le i C \sum_{k=k_{\min}}^{k_{\max}} c\, 2^{(\frac{\gamma}{2}-1)k} = 18\left( L(i\eta)^3 + \frac{i\eta}{2L} \right) e^{2L(i+1)\eta} \left( \sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k} \right)^{\!2} c\, \epsilon^{-2}.$$

The geometric sum $\sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k}$ can be bounded in three cases:
$$\sum_{k=k_{\min}}^{k_{\max}} 2^{(\frac{\gamma}{2}-1)k} \le \begin{cases} \frac{1}{1-2^{\frac{\gamma}{2}-1}}\, 2^{(\frac{\gamma}{2}-1)k_{\min}} & \gamma < 2,\\ k_{\max} + 1 - k_{\min} & \gamma = 2,\\ \frac{2^{\frac{\gamma}{2}-1}}{2^{\frac{\gamma}{2}-1}-1}\, 2^{(\frac{\gamma}{2}-1)k_{\max}} & \gamma > 2, \end{cases} \quad \le \quad \begin{cases} \frac{1}{1-2^{\frac{\gamma}{2}-1}}\, c^{\frac{1}{\gamma}-\frac{1}{2}} & \gamma < 2,\\ \log_2 \frac{8\, c^{\frac{1}{\gamma}}\, e^{L(T+\eta)}}{L\epsilon} & \gamma = 2,\\ \frac{2^{\frac{\gamma}{2}-1}}{2^{\frac{\gamma}{2}-1}-1} \left( \frac{L}{2}\, e^{-L(i+1)\eta}\, \epsilon \right)^{-(\frac{\gamma}{2}-1)} & \gamma > 2. \end{cases}$$

This leads to the computational bound
$$i \sum_{k=k_{\min}}^{k_{\max}} p_k\, c\, 2^{\gamma k} \le 18\left( (Li\eta)^3 + \frac{Li\eta}{2} \right) \cdot \begin{cases} \frac{1}{(1-2^{\frac{\gamma}{2}-1})^2} \left( \frac{c^{\frac{1}{\gamma}}\, e^{L(i+1)\eta}}{L\epsilon} \right)^{\!2} & \gamma < 2,\\ \left( \frac{c^{\frac{1}{\gamma}}\, e^{L(i+1)\eta}}{L\epsilon} \right)^{\!2} \log_2 \frac{8\, c^{\frac{1}{\gamma}}\, e^{L(i+1)\eta}}{L\epsilon} & \gamma = 2,\\ \frac{2^{3(\gamma-2)}}{(2^{\frac{\gamma}{2}-1}-1)^2} \left( \frac{c^{\frac{1}{\gamma}}\, e^{L(i+1)\eta}}{L\epsilon} \right)^{\!\gamma} & \gamma > 2, \end{cases}$$
which matches the claimed bound with $T = i\eta$ and $\log_2(8r) = 3 + \log_2 r$, concluding the proof.
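For concreteness, the ML-EM update analyzed above can be sketched in code. This is a toy scalar illustration with hypothetical drift levels, not the paper's implementation; the sanity check exploits that with all $p_k = 1$ the telescoping sum collapses to the finest drift, so the step reduces to plain Euler-Maruyama:

```python
import math
import random

def ml_em_step(y, t, eta, drifts, probs, sigma_t, rng):
    """One ML-EM step: y + eta * sum_k B_k/p_k * (f^k - f^{k-1}) + sqrt(eta) * sigma * Z.

    drifts[k](t, y) is the level-k drift approximation; the level below
    drifts[0] is taken to be the constant-zero function, as in the proof.
    """
    drift_est = 0.0
    prev = 0.0
    for fk, pk in zip(drifts, probs):
        cur = fk(t, y)
        if rng.random() < pk:  # B_k ~ Bernoulli(p_k)
            drift_est += (cur - prev) / pk
        prev = cur
    z = rng.gauss(0.0, 1.0)
    return y + eta * drift_est + math.sqrt(eta) * sigma_t * z

# Sanity check: with all p_k = 1 and sigma_t = 0, the step reduces to a
# plain Euler step with the finest drift drifts[-1].
drifts = [lambda t, y: -0.4 * y, lambda t, y: -0.49 * y, lambda t, y: -0.5 * y]
y0, eta = 1.0, 0.01
ml = ml_em_step(y0, 0.0, eta, drifts, [1.0, 1.0, 1.0], 0.0, random.Random(0))
em = y0 + eta * drifts[-1](0.0, y0)
assert abs(ml - em) < 1e-12
```

In actual use the $p_k$ decay with $k$ (as $C 2^{-(1+\gamma/2)k}$ in Theorem 2), so the cheap coarse levels are evaluated at every step while the expensive fine levels are evaluated only occasionally.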