vqvae的loss计算 loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) z_q是codebook 找到的最接近z的向量. z是encoder生成的向量. L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach()) # 这个部分对于encoder做了训练. L对z_q...
Mse代表均方误差,.detach作为停止梯度操作。 e_latent_loss = F.mse_loss(quantized.detach(), inputs) q_latent_loss = F.mse_loss(quantized, inputs.detach()) loss = q_latent_loss + self._commitment_cost * e_latent_loss 最后确保梯度可以直接从解码器流向编码器。 quantized = inputs + (quanti...
vqvae 训练不稳定主要还是由字典引起来的。codebook 奔溃的意思是,所有 embedding 都被分配给了其中部分...
上图zq就是ze通过与codebook中的e计算argmin,获得的隐变量序列。基于上面的直通估计原理,认为这个argmin是恒等操作。在计算梯度时,直接用重建loss相对zq(x)计算梯度,作为ze(x)的梯度,也就是encoder的更新依据也是zq(x)。下面公式中第一项。 如果只有第一项重建loss,codebook里面的embedding就无法接收到重建误差的梯...
综上,VQ-VAE共包含三个部分的训练loss:reconstruction loss,VQ loss,commitment loss。 其中reconstruction loss作用在encoder和decoder上,VQ loss用来更新embedding空间(也可用EMA方式),而commitment loss用来约束encoder,这里的为权重系数,论文默认设置为0.25。 另外,在实际实验中,一张图像会采用个离散隐变量,这个和...
VAE的loss 主要由两部分组成: (1)为了使输出和输入尽可能像,用MSE来约束。 (2)为了满足(1),随着训练不断进行,模型会倾向于产生固定的Z,即encoder输出的标准差接近0,VAE就会越来越像AE,这也就和我们的初衷相违背了。因此,我们要求N(mean, std)要逼近标准正态分布,这里使用KL散度来进行约束。
loss = q_latent_loss + self._commitment_cost*e_latent_loss 最后确保梯度可以直接从解码器流向编码器。 quantized= inputs + (quantized - inputs).detach() 从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略 ...
loss = q_latent_loss + self._commitment_cost * e_latent_loss 1. 2. 3. 最后确保梯度可以直接从解码器流向编码器。 quantized = inputs + (quantized - inputs).detach() 1. 从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略 ...
loss = q_latent_loss + self._commitment_cost * e_latent_loss 最后确保梯度可以直接从解码器流向编码器。 quantized = inputs + (quantized - inputs).detach() 从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略 ...
第一项相等于固定z,让zq靠近z,第二项则反过来固定zq,让z靠近zq。注意这个“等价”是对于反向传播(求梯度)来说的,对于前向传播(求loss)它是原来的两倍。根据我们刚才的讨论,我们希望“让zq去靠近z”多于“让z去靠近zq”,所以可以调一下最终的loss比例: ...