训练的大部分代码都和 DDPM 一样,不一样的只有损失函数以及重要性采样。这里定义一个专门用来计算 IDDPM 的损失函数。损失函数共由两项组成,也就是 L_\mathrm{simple} 和L_\mathrm{vlb}: def training_losses( iddpm: IDDPM, model: UNet2DModel, clean_images: torch.Tensor, noise: torch.Tensor, noisy...
如上图作者观察到前几步采样对损失函数负对数似然NLL(negative log-likelihood)贡献很大,而且正好前几步采样时{\beta}_{t}和\tilde{\beta}_{t}也差别很大,那就让它们结合一下,网络只需预测一个参数v: \Sigma_\theta(x_t,t)=\exp(v\log\beta_t+(1-v)\log\tilde{\beta}_t) 由于DDPM中的损失函数L...
而LT与参数θ无关,如果前向过程足够充分地破坏了原始数据分布,即有q(xT|x0)≈N(0,I),那么其损失值就接近 0 了。 DDPM 中还指出,式 (2) 所定义的前向加噪过程中是相邻单步的加噪,实际上从原始数据x0开始,可以直接完成任意多步加噪。记αt:=1−βt,α¯t:=∏s=0tαs,边缘分布可重写为: q...
计算Loss的时候,如果损失函数的类型是MSE或者Rescaled_MSE,则计算DDPM里面的simple形式的预测噪声的损失或者变分下界对应的的loss。 如果高斯分布的噪声需要预测,则会计算vlb,而且不会影响对均值的学习,只会影响对方差的学习。学习的target可以是 x_{t-1}、 x_0 或者\epsilon (一般是这个)。 model_output = model...