CVAE (Conditional Variational Autoencoder)
notations
- $x$: image
- $z$: latent variable
- $y$: label (omitted below to lighten notation)
- $p(x|z)$: decoder (`Decoder`)
- $q(z|x)$: encoder (`Encoder`)
- $\hat{p}(z)$: learned prior, fitted by variational inference (`PriorEncoder`)
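Written with the label explicit (it is suppressed everywhere below), the three modules play the standard CVAE roles:

$$
p(x, z \mid y) = \underbrace{p(x \mid z, y)}_{\texttt{Decoder}}\, p(z \mid y),
\qquad
\underbrace{\hat{p}(z \mid y)}_{\texttt{PriorEncoder}} \approx p(z \mid y),
\qquad
\underbrace{q(z \mid x, y)}_{\texttt{Encoder}} \approx p(z \mid x, y)
$$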
model structure
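The code below uses `Encoder`, `Decoder`, and `PriorEncoder` without showing them. A minimal sketch of what they could look like is given first as a point of reference; the layer widths and the latent size `z_dim` are assumptions, and only the call signatures (`encoder(x, y) -> (mu, sigma)`, `priorEncoder(y) -> (mu, sigma)`, `decoder(z, y) -> x_hat`) are fixed by the `CVAE` class that follows.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """q(z|x): maps a flattened image and a one-hot label to (mu, sigma)."""
    def __init__(self, x_dim=784, y_dim=10, h_dim=256, z_dim=16):  # sizes are assumptions
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim + y_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_sigma = nn.Linear(h_dim, z_dim)

    def forward(self, x, y):
        h = self.net(torch.cat([x, y], dim=-1))
        # exponentiate so sigma is strictly positive (the KL term divides by it)
        return self.mu(h), torch.exp(self.log_sigma(h))


class PriorEncoder(nn.Module):
    """p_hat(z): a learned, label-conditioned prior over the latent code."""
    def __init__(self, y_dim=10, h_dim=64, z_dim=16):  # sizes are assumptions
        super().__init__()
        self.net = nn.Sequential(nn.Linear(y_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_sigma = nn.Linear(h_dim, z_dim)

    def forward(self, y):
        h = self.net(y)
        return self.mu(h), torch.exp(self.log_sigma(h))


class Decoder(nn.Module):
    """p(x|z): reconstructs a flattened image from the latent code and label."""
    def __init__(self, z_dim=16, y_dim=10, h_dim=256, x_dim=784):  # sizes are assumptions
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + y_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim), nn.Sigmoid())

    def forward(self, z, y):
        return self.net(torch.cat([z, y], dim=-1))
```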
```python
class CVAE(nn.Module):
    def __init__(self, config):
        super(CVAE, self).__init__()
        self.encoder = Encoder(...)
        self.decoder = Decoder(...)
        self.priorEncoder = PriorEncoder(...)

    def forward(self, x, y):
        x = x.reshape((-1, 784))  # MNIST
        mu, sigma = self.encoder(x, y)
        prior_mu, prior_sigma = self.priorEncoder(y)
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = torch.randn_like(mu)
        z = z * sigma + mu
        reconstructed_x = self.decoder(z, y)
        reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
        return reconstructed_x, mu, sigma, prior_mu, prior_sigma

    def infer(self, y):
        # sample from the label-conditioned prior instead of the posterior
        prior_mu, prior_sigma = self.priorEncoder(y)
        z = torch.randn_like(prior_mu)
        z = z * prior_sigma + prior_mu
        reconstructed_x = self.decoder(z, y)
        return reconstructed_x


class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()
        self.loss_fn = nn.MSELoss(reduction='mean')
        self.kld_loss_weight = 1e-5

    def forward(self, x, reconstructed_x, mu, sigma, prior_mu, prior_sigma):
        # reconstruction term
        mse_loss = self.loss_fn(x, reconstructed_x)
        # closed-form KL between N(mu, sigma^2) and the learned prior N(prior_mu, prior_sigma^2)
        kld_loss = torch.log(prior_sigma / sigma) + (sigma**2 + (mu - prior_mu)**2) / (2 * prior_sigma**2) - 0.5
        kld_loss = torch.sum(kld_loss) / x.shape[0]
        loss = mse_loss + self.kld_loss_weight * kld_loss
        return loss


def train(model, criterion, optimizer, data_loader, config):
    train_task_time_str = time_str()
    for epoch in range(config.num_epoch):
        loss_seq = []
        for step, (x, y) in tqdm(enumerate(data_loader)):
            # -------------------- data --------------------
            x = x.to(device)
            y = y.to(device)
            # -------------------- forward --------------------
            reconstructed_x, mu, sigma, prior_mu, prior_sigma = model(x, y)
            loss = criterion(x, reconstructed_x, mu, sigma, prior_mu, prior_sigma)
            # -------------------- log --------------------
            loss_seq.append(loss.item())
            # -------------------- backward --------------------
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # -------------------- end --------------------
        logging.info(f'epoch {epoch:^5d} loss {sum(loss_seq[-config.batch_size:]) / config.batch_size:.5f}')
        with torch.no_grad():
            # -------------------- file --------------------
            path = f'{config.save_fig_path}/{train_task_time_str}'  # type(model).__name__
            if not os.path.exists(path):
                os.makedirs(path)
            path += f'/epoch{epoch:04d}.png'
            # -------------------- figure --------------------
            plt.close()
            fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 2), dpi=512)
            fig.suptitle(f'epoch {epoch} loss {sum(loss_seq[-config.batch_size:]) / config.batch_size:.5f}')
            # -------------------- infer --------------------
            y = torch.Tensor(list(range(config.num_class)))
            y = y.to(dtype=torch.int64)
            y = nn.functional.one_hot(y, num_classes=config.num_class)
            y = y.to(dtype=torch.float)
            y = y.to(device)
            x = model.infer(y)
            x = x.cpu()
            x = x.numpy()
            x += x.min()
            x /= x.max()
            x *= 255
            x = x.astype(np.uint8)
            # -------------------- plot --------------------
            for idx, ax, arr in zip(range(config.num_class), axs, x):
                ax.set_title(str(idx))
                ax.axis('off')
                ax.imshow(arr.reshape((28, 28)), cmap='BuGn')
            # -------------------- save --------------------
            # plt.show()
            plt.savefig(path)
            # -------------------- end --------------------
```
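The post does not show how the model, loss, and training loop are put together, so the sketch below is one hypothetical wiring for MNIST. Every concrete value in it (epochs, batch size, learning rate, paths) is an assumption, the helpers the training loop expects (`time_str()`, `logging`, `plt`, `np`, `os`, `tqdm`, a global `device`) are taken to already exist, and the elided `...` constructor arguments inside `CVAE.__init__` still have to be filled in for whatever sub-modules are actually used.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# all concrete values here are assumptions for illustration
config = SimpleNamespace(
    num_epoch=16,
    batch_size=64,
    num_class=10,
    save_fig_path='./figs',
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# the encoders expect one-hot float labels (as in CVAE.infer), and the Loss
# compares x against a (B, 28, 28) reconstruction, so the channel dim is squeezed
transform = transforms.Compose([transforms.ToTensor(), lambda t: t.squeeze(0)])
target_transform = lambda y: nn.functional.one_hot(
    torch.tensor(y), num_classes=config.num_class).float()

dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transform, target_transform=target_transform)
data_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

model = CVAE(config).to(device)
criterion = Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed learning rate

train(model, criterion, optimizer, data_loader, config)
```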
dynamics

- Following the principle of the EM algorithm, the SGVB (stochastic gradient + variational Bayesian) framework uses variational inference to optimize the ELBO.
- $\log p(x) = \mathrm{ELBO}\left(q(z|x), p(x|z)\right) + \mathrm{KL}\left(q(z|x) \,\|\, p(z|x)\right)$. The ELBO is a surrogate for the log-likelihood.
- $\mathrm{ELBO} = \mathbb{E}_{q(z|x)}\left[\log p(x|z)p(z)\right] + \mathrm{Entropy}\left(q(z|x)\right) = \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)$
- The first form, $\mathbb{E}_{q(z|x)}\left[\log p(x|z)p(z)\right] + \mathrm{Entropy}\left(q(z|x)\right)$, is used to prove the correctness of the EM algorithm.
- The second form, $\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)$, is the one optimized by the neural network.
- $\max \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] \stackrel{\textsf{sampling}}{\approx} \max q(z_i|x_i)\log p(x_i|z_i) \stackrel{\textsf{opposite}}{=} \min \mathtt{cross\_entropy\_loss} \stackrel{\textsf{substitution}}{\sim} \min \mathtt{mse\_loss}$
- $\min \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right) \stackrel{\textsf{Variational Inference}}{\approx} \min \mathrm{KL}\left(q(z|x) \,\|\, \hat{p}(z)\right)$ (closed form below)
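For two Gaussians the KL divergence has a closed form, which is exactly what the `kld_loss` expression in the `Loss` class evaluates per latent dimension, with $(\mu, \sigma)$ from the encoder and $(\hat{\mu}, \hat{\sigma})$ from the prior encoder:

$$
\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(\hat{\mu}, \hat{\sigma}^2)\right)
= \log\frac{\hat{\sigma}}{\sigma} + \frac{\sigma^2 + (\mu - \hat{\mu})^2}{2\hat{\sigma}^2} - \frac{1}{2}
$$

Summing over latent dimensions and averaging over the batch gives the `kld_loss` term that is added to the MSE with weight `kld_loss_weight = 1e-5`.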
The hyperparameter settings above converge fairly quickly to good model parameters. Experimentally, a larger batch_size makes training converge to worse model parameters, and a smaller learning_rate makes convergence very slow.
- The network first learns the rough extent of the digit, then its shape: in epochs 0-3 the digits contain many noisy pixels, while in epochs 4-15 they appear as smooth shapes.
- The network learns the foreground (the digit) before the background (white): in epochs 0-10 the background is dark, and in epochs 11-15 it is bright.
- From epoch 11 the network starts learning the least important detail (the white background), and from epoch 12 it gradually overfits; in particular, by epoch 15 the digit 0 looks like an 8.