MATLAB实现自编码器(五)——变分自编码器(VAE)实现图像生成的帮助函数
本文是對Train Variational Autoencoder (VAE) to Generate Images網頁的翻譯,該網頁實現了變分自編碼的圖像生成,以MNIST手寫數字為訓練數據,生成了相似的圖像。本文主要翻譯了網頁中幫助函數外的部分。主要部分見MATLAB實現自編碼器(四)——變分自編碼器實現圖像生成Train Variational Autoencoder (VAE) to Generate Images。
processImagesMNIST
首先是兩個用于處理mnist數據集的函數,分別處理圖片和標簽,使其符合網絡的輸入要求。
function X = processImagesMNIST(filename) % The MNIST processing functions extract the data from the downloaded IDX % files into MATLAB arrays. The processImagesMNIST function performs these % operations: Check if the file can be opened correctly. Obtain the magic % number by reading the first four bytes. The magic number is 2051 for % image data, and 2049 for label data. Read the next 3 sets of 4 bytes, % which return the number of images, the number of rows, and the number of % columns. Read the image data. Reshape the array and swaps the first two % dimensions due to the fact that the data was being read in column major % format. Ensure the pixel values are in the range [0,1] by dividing them % all by 255, and converts the 3-D array to a 4-D dlarray object. Close the % file.[fileID,errmsg] = fopen(filename,'r','b'); if fileID < 0error(errmsg); endmagicNum = fread(fileID,1,'int32',0,'b'); if magicNum == 2051fprintf('\nRead MNIST image data...\n') endnumImages = fread(fileID,1,'int32',0,'b'); fprintf('Number of images in the dataset: %6d ...\n',numImages); numRows = fread(fileID,1,'int32',0,'b'); numCols = fread(fileID,1,'int32',0,'b');X = fread(fileID,inf,'unsigned char');X = reshape(X,numCols,numRows,numImages); X = permute(X,[2 1 3]); X = X./255; X = reshape(X, [28,28,1,size(X,3)]); X = dlarray(X, 'SSCB');fclose(fileID); endprocessImagesMNIST
處理標簽,使其符合網絡的輸入要求
function Y = processLabelsMNIST(filename) % The processLabelsMNIST function operates similarly to the % processImagesMNIST function. After opening the file and reading the magic % number, it reads the labels and returns a categorical array containing % their values.[fileID,errmsg] = fopen(filename,'r','b');if fileID < 0error(errmsg); endmagicNum = fread(fileID,1,'int32',0,'b'); if magicNum == 2049fprintf('\nRead MNIST label data...\n') endnumItems = fread(fileID,1,'int32',0,'b'); fprintf('Number of labels in the dataset: %6d ...\n',numItems);Y = fread(fileID,inf,'unsigned char');Y = categorical(Y);fclose(fileID); endModel Gradients Function
The modelGradients function takes the encoder and decoder dlnetwork objects and a mini-batch of input data X, and returns the gradients of the loss with respect to the learnable parameters in the networks. The function performs three operations:
- Obtain the encodings by calling the sampling function on the mini-batch of images that passes through the encoder network.
- Obtain the loss by passing the encodings through the decoder network and calling the ELBOloss function.
- Compute the gradients of the loss with respect to the learnable parameters of both networks by calling the dlgradient function.
modelGradients函數獲取編碼器和解碼器的dlnetwork對象以及輸入數據X的小批量,并返回網絡中可訓練參數的損失梯度。 該函數執行三個操作:
- 通過在通過編碼器網絡的微型圖像批次上調用采樣函數來獲取編碼。
- 通過使編碼通過解碼器網絡并調用ELBOloss函數來獲得損耗。
- 通過調用dlgradient函數,針對兩個網絡的可學習參數計算損耗的梯度。
Sampling and Loss Functions
The sampling function obtains encodings from input images. Initially, it passes a mini-batch of images through the encoder network and splits the output of size (2*latentDim)miniBatchSize into a matrix of means and a matrix of variances, each of size latentDimbatchSize. Then, it uses these matrices to implement the reparameterization trick and to compute the encoding. Finally, it converts this encoding to a dlarray object in SSCB format.
Sampling 函數從輸入圖像獲取編碼。 最初,它通過編碼器網絡傳遞一個圖像的小批量,并將大小(2 × latentDim) × miniBatchSize的輸出分成均值矩陣和方差矩陣,每個大小均為latentDim × batchSize。 然后,它使用這些矩陣來實現重新參數化技巧并計算編碼。 最后,它將這種編碼轉換為SSCB格式的dlarray對象。
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x); d = size(compressed,1)/2; zMean = compressed(1:d,:); zLogvar = compressed(1+d:end,:);sz = size(zMean); epsilon = randn(sz); sigma = exp(.5 * zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z, 'SSCB'); endELBOloss
The ELBOloss function takes the encodings of the means and the variances returned by the sampling function, and uses them to compute the ELBO loss.
ELBOloss函數采用均值和采樣函數返回的方差的編碼,并使用它們來計算ELBO損耗。
function elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2; reconstructionLoss = sum(squares, [1,2,3]);KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);elbo = mean(reconstructionLoss + KL); endVisualization Functions
The VisualizeReconstruction function randomly chooses two images for each digit of the MNIST data set, passes them through the VAE, and plots the reconstruction side by side with the original input. Note that to plot the information contained inside a dlarray object, you need to extract it first using the extractdata and gather functions.
VisualizeReconstruction函數為MNIST數據集的每個數字隨機選擇兩個圖像,將它們通過VAE,然后與原始輸入并排繪制。 請注意,要繪制dlarray對象中包含的信息,需要先使用extractdata and gather函數將其提取出來。
function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet) f = figure; figure(f) title("Example ground truth image vs. reconstructed image") for i = 1:2for c=0:9idx = iRandomIdxOfClass(YTest,c);X = XTest(:,:,:,idx);[z, ~, ~] = sampling(encoderNet, X);XPred = sigmoid(forward(decoderNet, z));X = gather(extractdata(X));XPred = gather(extractdata(XPred));comparison = [X, ones(size(X,1),1), XPred];subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),end end endfunction idx = iRandomIdxOfClass(T,c) idx = T == categorical(c); idx = find(idx); idx = idx(randi(numel(idx),1)); endVisualizeLatentSpace
The VisualizeLatentSpace function visualizes the latent space defined by the mean and the variance matrices that form the output of the encoder network, and locates the clusters formed by the latent space representations of each digit.
VisualizeLatentSpace函數可視化由形成編碼器網絡輸出的均值和方差矩陣定義的潛在空間,并找到由每個數字的潛在空間表示形式形成的聚類。
The function starts by extracting the mean and the variance matrices from the dlarray objects. Because transposing a matrix with channel/batch dimensions (C and B) is not possible, the function calls stripdims before transposing the matrices. Then, it carries out a principal component analysis (PCA) on both matrices. To visualize the latent space in two dimensions, the function keeps the first two principal components and plots them against each other. Finally, the function colors the digit classes so that you can observe clusters.
該函數首先從dlarray對象中提取均值和方差矩陣。 由于無法轉置具有通道/批處理尺寸(C和B)的矩陣,因此該函數在轉置矩陣之前調用stripdims。 然后,它對兩個矩陣執行主成分分析(PCA)。 為了在兩個維度上可視化潛在空間,該函數保留前兩個主要成分并將其相互繪制。 最后,該函數為數字類著色,以便觀察群集。
function visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest);zMean = stripdims(zMean)'; zMean = gather(extractdata(zMean));zLogvar = stripdims(zLogvar)'; zLogvar = gather(extractdata(zLogvar));[~,scoreMean] = pca(zMean); [~,scoreLogvar] = pca(zLogvar);c = parula(10); f1 = figure; figure(f1) title("Latent space")ah = subplot(1,2,1); scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; axis equal xlabel("Z_m_u(2)") ylabel("Z_m_u(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);ah = subplot(1,2,2); scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; xlabel("Z_v_a_r(2)") ylabel("Z_v_a_r(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); axis equal endgenerate
The generate function tests the generative capabilities of the VAE. It initializes a dlarray object containing 25 randomly generated encodings, passes them through the decoder network, and plots the outputs.
生成函數測試VAE的生成能力。 它初始化包含25個隨機生成的編碼的dlarray對象,將它們傳遞通過解碼器網絡,并繪制輸出。
function generate(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB'); generatedImage = sigmoid(predict(decoderNet, randomNoise)); generatedImage = extractdata(generatedImage);f3 = figure; figure(f3) imshow(imtile(generatedImage, "ThumbnailSize", [100,100])) title("Generated samples of digits") drawnow end總結
以上是生活随笔為你收集整理的MATLAB实现自编码器(五)——变分自编码器(VAE)实现图像生成的帮助函数的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 关于数据库表的规范设计
- 下一篇: 手动修改美化7zip图标 - 附替换文件