首页   注册   登录
V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
推荐学习书目
Learn Python the Hard Way
Python 学习手册
Python Cookbook
Python 基础教程
Python Sites
PyPI - Python Package Index
http://www.simple-is-better.com/
http://diveintopython.org/toc/index.html
Pocoo
值得关注的项目
PyPy
Celery
Jinja2
Read the Docs
gevent
pyenv
virtualenv
Stackless Python
Beautiful Soup
结巴中文分词
Green Unicorn
Sentry
Shovel
Pyflakes
pytest
Python 编程
pep8 Checker
Styles
PEP 8
Google Python Style Guide
Code Style from The Hitchhiker's Guide
V2EX  ›  Python

手把手带你写一个 GAN

  •  
  •   MoModel · 35 天前 · 620 次点击
    这是一个创建于 35 天前的主题,其中的信息可能已经有所发展或是发生改变。

    GAN

    前言

    我们接下来从这几个方面来探索.

    • GAN能用来做什么
    • GAN的原理
    • GAN的代码实现

    用途

    GAN自 2014 年诞生以来, 就一直备受关注, 著名的应用也随即产出, 比如比较著名的 GAN 的应用有 Pix2Pix、CycleGAN 等, 大家也将它用于各个地方.

    1. 缺失 /模糊像素的补充
    2. 图片修复
    3. etc...

    我觉得还有一个比较重要的用途, 很多人都会缺少数据集, 那么就通过GAN去生成数据集了, 通过调节部分参数来进行数据集的产生的相似度.

    原理

    GAN 的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)D(Discriminator)。正如它的名字所暗示的那样, 它们的功能分别是:

    G是一个生成图片的网络, 它接收一个随机的噪声(随机生成的图片)z, 通过这个噪声生成图片,记做G(z). D是一个判别网络, 判别一张图片是不是“真实的”。它的输入参数是x, x代表一张图片,输出D(x)代表x为真实图片的概率,如果为>0.5,就代表是真实(相似)的图片,反之,就代表不是真实的图片。

    图片来自于知乎

    我们通过一个假产品宣传的例子来理解:

    首先, 我们来定义一下角色.

    1. 进行宣传的'专家'(生成网络)
    2. 正在听讲的'我们'(判别网络)

    '专家'的手里面拿着一堆高仿的产品, 正在进行宣讲, 我们是熟知真品的相关信息的, 通过去对比两个产品之间的差距, 来判断是赝品的可能性.

    这时, 我们就可以引出来一个概念, 如果'专家'团队比较厉害, 完美的仿造了我们的判断依据, 比如说产出方, 发明日期, 说明文等等, 那么我们就会觉得他是真的, 那么他就是一个好的生成网络, 反之, 我们会判断他是赝品.

    我们(判别网络)出发, 我们的判断条件越苛刻, 赝品和真品之间的差距会越来越小, 这样的最后的产出就是真假难分, 完全被模仿了.

    相关资源

    Generative Adversarial Networks, 深层的原理推荐大家可以去阅读以下这一篇论文.

    * 损失函数等相关细节我们在实现里在讲

    实现

    接下来我们就以 mnist来实现 GAN吧.

    1. 首先, 我们先下载数据集.
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)
    

    我们通过tensorflow去下载mnist的数据集, 然后加载到内存, one-hot参数决定我们的label是否要经过编码(mnist 数据集是有 10 个类别), 但是我们判别网络是对比真实的和生成的之间的区别以及相似的可能性, 所以不需要执行one-hot编码了.

    这里读取出来的图片已经归一化到[0, 1]之间了.

    1. 俗话说, 知己知彼, 百战百胜, 那我们拿到数据集, 就先来看看它长什么样.
    def show_images(images):
        images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
        sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
        sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
    
        fig = plt.figure(figsize=(sqrtn, sqrtn))
        gs = gridspec.GridSpec(sqrtn, sqrtn)
        gs.update(wspace=0.05, hspace=0.05)
    
        for i, img in enumerate(images):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    

    运行截图

    这里有一个小问题, 如果是在Notebook中执行, 记得加上这句话, 否则需要执行两次才会绘制.

    %matplotlib inline
    
    1. 数据看过了, 我们该对它进行一定的处理了, 这里我们只是将数据缩放到[-1, 1]之间.
    def preprocess_img(x):
        return 2 * x - 1.0
    
    def deprocess_img(x):
        return (x + 1.0) / 2.0
    
    1. 数据处理完了, 接下来我们要开始搭建模型了, 这一部分我们有两个模型, 一个生成网络, 一个判别网络.

    生成网络

    def generator(z):
    
        with tf.variable_scope("generator"):
    
            fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
            bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
            fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
            bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
            reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
            conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
                                                        padding='same')
            bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
            conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
                                            padding='same')
    
            img = tf.reshape(conv_transpose2, shape=[-1, 784])
            return img
    

    判别网络

    def discriminator(x):
    
        with tf.variable_scope("discriminator"):
    
            unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
            conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
            maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
            conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
            maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
            flatten = tf.reshape(maxpool2, shape=[-1, 1024])
            fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
            logits = tf.layers.dense(inputs=fc1, units=1)
    
            return logits
    

    激活函数我们使用了leaky_relu, 他的代码实现是

    def leaky_relu(x, alpha=0.01):
        activation = tf.maximum(x,alpha*x)
        return activation
    

    它和 relu的区别就是, 小于 0 的值也会给与一点小的权重进行保留.

    1. 建立损失函数
    def gan_loss(logits_real, logits_fake):
    
        # Target label vector for generator loss and used in discriminator loss.
        true_labels = tf.ones_like(logits_fake)
    
        # DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
        # classifies fake images.
        real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
        fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)
    
        # Combine and average losses over the batch
        D_loss = real_image_loss + fake_image_loss
        D_loss = tf.reduce_mean(D_loss)
    
        # GENERATOR is trying to make the discriminator output 1 for all its images.
        # So we use our target label vector of ones for computing generator loss.
        G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)
    
        # Average generator loss over the batch.
        G_loss = tf.reduce_mean(G_loss)
    
        return D_loss, G_loss
    

    损失我们分为两部分, 一部分是生成网络的, 一部分是判别网络的.

    生成网络的损失定义为, 生成图像的类别与真实标签(全是 1)的交叉熵损失.

    $$ loss_{generate} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) $$

    判别网络的损失定义为, 我们将真实图片的标签设置为 1, 生成图片的标签设置为 0, 然后由真实图片的输出以及生成图片的输出的交叉熵损失和.

    T: True, G: Generate

    生成损失

    $$ loss_{G} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) $$

    真实图片损失

    $$ loss_{T} = -\sum_i^n(Y_ilogT_i + (1 - Y_i)logT_i) $$

    总损失

    $$ Loss_{D} = loss_{G} + loss_{T} $$

    1. 训练
    def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\
                  show_every=250, print_every=50, batch_size=128, num_epoch=10):
        # compute the number of iterations we need
        max_iter = int(mnist.train.num_examples*num_epoch/batch_size)
        for it in range(max_iter):
            # every show often, show a sample result
            if it % show_every == 0:
                samples = sess.run(G_sample)
                fig = show_images(samples[:16])
                plt.show()
                print()
            # run a batch of data through the network
            minibatch,minbatch_y = mnist.train.next_batch(batch_size)
            _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
            _2, G_loss_curr = sess.run([G_train_step, G_loss])
            if it % show_every == 0:
                print(_1,_2)
            # print loss every so often.
            # We want to make sure D_loss doesn't go to 0
            if it % print_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr))
        print('Final images')
        samples = sess.run(G_sample)
    
        fig = show_images(samples[:16])
        plt.show()
    

    这里就是开始训练了, 并展示训练的结果.

    1. 查看结果

    刚开始的时候, 还没学会怎么模仿.

    经过学习改进.

    项目地址

    源码看这

    声明

    该文章参考了知乎的文章.

    目前尚无回复
    关于   ·   FAQ   ·   API   ·   我们的愿景   ·   广告投放   ·   感谢   ·   实用小工具   ·   2459 人在线   最高记录 4385   ·  
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.3 · 17ms · UTC 14:40 · PVG 22:40 · LAX 06:40 · JFK 09:40
    ♥ Do have faith in what you're doing.
    沪ICP备16043287号-1