GAN tutorial

0. Giới thiệu

Neural network vẫn đang phát triển rất mạnh trong những năm gần đây. Trong năm 2017, 1 loại neural network có tên là GAN ( Generative adversarial network ) đã phát triển rất mạnh mẽ. Nhằm mục đích giúp các bạn tiếp cận nhanh chóng với mô hình này. Mình sẽ giới thiệu qua nguyên lý và cách thức train 1 mô hình GAN cơ bản cho các bạn.

Bài viết gồm có 3 phần:
1. Ý tưởng cở bản về GAN
2. Cách thức train 1 GAN về mặt lý thuyết với mã giá
3. Train GAN với keras cho tập dataset MNIST

Với các bạn muốn có code ngay thì có thể check ở đây : https://gist.github.com/astrung/0af1899cd73f1eaab60157fc9f14242b

1.Ý tưởng cơ bản về GAN

Chúng ta hãy bắt đầu với một câu chuyện thực tế trong cuộc sống của chúng ta để hiểu về nguyên tắc hoạt động của GAN, đó là cuộc đấu tranh giữa cảnh sát và người làm tiền giả :

  • Người làm tiền giả không hề biết cách để tạo ra tiền thật, họ chỉ cố gắng mô phỏng lại đồng tiền thật.

  • Cảnh sát có nhiệm vụ phát hiện tiền giả - tiền thật.

  • Người làm tiền giả cố gắng lừa cảnh sát rằng tiền mình tạo ra là thật.

  • Dựa trên việc tiền giả có bị phát hiện hay không, người làm tiền giả sẽ cải tiến phương pháp.

  • Tiền giả ngày càng được cải tiến tinh vi để qua mặt cảnh sát. Bởi vậy cảnh sát sẽ cải tiến phương pháp phát hiện.

  • 2 bên cạnh tranh, dẫn đến tiền giả ngày càng giống tiền thật và các phương pháp phát hiện của cảnh sát cũng hiện đại hơn.

Trong câu chuyện trên, chúng ta có thể thấy rằng sự cạnh tranh giữa 2 bên đã làm 2 bên ngày cành cải tiến các phương pháp để qua mặt lẫn nhau. Nhưng nếu xét trên 1 khía cạnh khác, cả 2 bên đang ngày càng tốt lên.
GAN là 1 mô hình cũng mô phỏng quá trình cạnh tranh này để nhằm mục đích cải tiến mô hình. Tất nhiên là để cạnh tranh thì trước tiên cần phải có 2 thực thể có mục đích đối lập nhau đã. Vậy nên khi nói về mô hình GAN và train GAN, thục chất chúng ta đang nói tới 2 mô hình neural network:

  • Generator neural network ( Người làm tiền giả ) : sinh ra những dữ liệu
    giống thực tế nhất có thể ( sinh ra tiền giả đó ) từ 1 nguồn noise nào đó ( thông thường là 1 phân bố xác suất ), có nhiệm vụ đánh lừa Discriminator rằng những dữ liệu mình sinh ra là thật ( dữ liệu gần với thực tế nhất có thể ).
  • Discriminator neural network ( Cảnh sát ): phát hiện đó là dữ liệu thực tế hay dữ liệu giả do Generator sinh ra.

Vậy 1 mô hình GAN thường được sử dụng để tạo ra các dữ liệu giống như trong thực tế. Ví dụ như chúng ta có 1 tập ảnh các chữ số đen trắng ( MNIST dataset ), và giờ chúng ta muốn sinh ra các ảnh chữ số đen trắng mới thì nhiệm vụ của Generator sẽ là sinh ra các ảnh đó, còn Discriminator sẽ có nhiệm vụ phân biệt giữa những bức ảnh được Generator sinh ra và những bức ảnh từ thực tế, và sau đó phản hồi sự khác nhau để Discriminator có thể chỉnh sửa những bức ảnh mình sinh ra.Thông qua sự cạnh tranh giữa Discriminator và Generator, các bức ảnh được sinh ra sẽ ngày càng tốt hơn. Bạn có thể xem thêm trong hình sau :

2.Training 1 mô hình GAN :

Như các bạn đã thấy ở mục trên, thực chất 1 mô hình GAN là tổng hợp của 2 mô hình : Generator và Discriminator. Vậy thì thực chất việc training GAN sẽ bao gồm việc training 2 mô hình :

2.1 Training discriminator :

Mục tiêu của chúng ta là giúp Discriminator học để có thể phân biệt được giữa data thực tế và fake data từ generator. Đây thực chất là bài toán classification. Vậy chúng ta sẽ train discriminator theo các bước sau :

  • Sử dụng generator (chưa được train ) để sinh ra một số dữ liệu giả - fake data
  • Chúng ta sẽ đánh nhãn : real data có giá trị 1, fake data có giá trị 0 => tạo training set cho discriminator.Đến đây chúng ta đã quay lại với 1 bài toán binary classification đơn giản.
  • Chúng ta sẽ train discriminator với training set vừa mới tạo ở bước trên.

2.2 Training generator :

Sau khi training Discriminator, chúng ta sẽ training Generator dựa trên sự phản hồi của Discriminator theo các bước sau :

  • Sử dụng generator (chưa được train ) để sinh ra một số dữ liệu giả - fake data
  • Chúng ta sẽ đánh nhãn : fake data có giá trị 1, và sử dụng đây như là training set trong bước này . Vậy khi train generator, chúng ta không sử dụng real data.
  • Chúng ta chain mô hình generator vào sau mô hình discriminator - output của generator sẽ được nối với input của discriminator. Như vậy các framework deeplearning có thể sử dụng quá trình backpropagation để học cho cả 2 mô hình.
  • Khi generator tạo ảnh và được chuyển tiếp cho discriminator, discriminator sẽ phân biệt thật giả và truyền sự phản hồi thông qua backpropagation tới generator để generator sửa đổi mô hình.
  • Tuy nhiên, chúng ta chỉ muốn train generator. Việc backpropagation sửa đổi mô hình của discriminator ( mà chúng ta vừa mới train trước đó ) có thể khiên mô hình tệ đi > trong quá trình train generator, chúng ta cần freeze mô hình của discriminator để tránh việc các framework deeplearning tự động chỉnh sửa mô hình discriminator.

Chúng ta sẽ train 2 mô hình tuần tự với nhau. Thông thường ở mỗi bước, chúng ta sẽ train discriminator 1 lần, sau đó train generator 1 lần, sau đó cứ thê lặp lại. Tuy nhiên chúng ta cúng có thể train 2,...n lần 1 mô hình ở mỗi bước, tùy theo sự phức tạp của bài toán.

3.Train GAN với keras cho tập dataset MNIST:

Trong ví dụ này, chúng ta sẽ sử dụng tập MNIST cho mô hình GAN, nhằm mục đích tạo ra những bức ảnh "fake" của số viết tay bằng framework keras. Bạn nào chưa biết về keras có thể tham khảo ở đây : https://keras.io/

Đầu tiên chúng ta sẽ định nghĩa cấu trúc của Discriminator. Vì chúng ta đang nhận dạng ảnh nên sẽ sử dụng mô hình convolution :

class DCGAN(object):
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator builder
        self.G = None   # generator builder
        self.AM = None  # generator model
        self.DM = None  # discriminator model
        
    def discriminator(self):
        if self.D:
            return self.D
        self.D = Sequential()
        depth = 64
        dropout = 0.4
        # In: 28 x 28 x 1, depth = 1
        # Out: 14 x 14 x 1, depth=64
        input_shape = (self.img_rows, self.img_cols, self.channel)
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,\
            padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        # Out: 1-dim probability
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D

Tiếp theo chúng ta sẽ xây dựng cấu trúc của Generator. Đầu ra của chúng ta là 1 bức ảnh bất kì, đầu vào là 100 giá trị từ 1 hàm noise.

def generator(self):
    if self.G:
        return self.G
    self.G = Sequential()
    dropout = 0.4
    depth = 64+64+64+64
    dim = 7
    # In: 100
    # Out: dim x dim x depth
    self.G.add(Dense(dim*dim*depth, input_dim=100))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Reshape((dim, dim, depth)))
    self.G.add(Dropout(dropout))

    # In: dim x dim x depth
    # Out: 2*dim x 2*dim x depth/2
    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))

    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))

    self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))

    # Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
    self.G.add(Conv2DTranspose(1, 5, padding='same'))
    self.G.add(Activation('sigmoid'))
    self.G.summary()
    return self.G

Giờ chúng ta cần phải định nghĩa loss function cho từng mô hình. Với discriminator, chúng ta cần phân biệt 2 loại nhãn ( giả và thật ), vậy chúng ta sẽ sử dụng binary cross entropy :

def discriminator_model(self):
    if self.DM:
        return self.DM
    optimizer = RMSprop(lr=0.0002, decay=6e-8)
    self.DM = Sequential()
    self.DM.add(self.discriminator())
    self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
        metrics=['accuracy'])
    return self.DM

Tiếp theo chúng ta sử dụng loss function và optimizer cho generator. Như mình đã trình bày ở trên , chúng ta sẽ nối 2 mô hình với nhau để kết quả cuối cùng của discriminator ( 0 và 1 - thật hay giả ) có thể lan truyền ngược trở về generator. Đầu ra cuối cùng của chúng ta vẫn là output 0 và 1 của generator, vậy chúng ta sẽ dùng loss function giống như lúc trước.

 def adversarial_model(self):
    if self.AM:
        return self.AM
    optimizer = RMSprop(lr=0.0001, decay=3e-8)
    self.AM = Sequential()
    self.AM.add(self.generator())
    self.AM.add(self.discriminator())
    self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
        metrics=['accuracy'])
    return self.AM

Cuối cùng là quá trình train của cả 2 mô hình. Đầu tiên , chúng ta sẽ train discriminator

    images_train = self.x_train[np.random.randint(0,
        self.x_train.shape[0], size=batch_size), :, :, :]
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
    images_fake = self.generator.predict(noise)
    x = np.concatenate((images_train,images_fake))#data có cả real data + fake data
    y = np.ones([2*batch_size, 1])
    y[batch_size:, :] = 0 
    #y[0:batch_size] là real data sẽ có nhãn là 1
    #y[batch_size:] là fake data sẽ có nhãn là 0
    d_loss = self.discriminator.train_on_batch(x, y)

Sau đó chúng ta train generator :

        y = np.ones([batch_size, 1]) # fake data bây giờ sẽ có nhãn là 1.
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
        a_loss = self.adversarial.train_on_batch(noise, y)

Đây là kết quả do mình thực hiện :

Các bạn có thể tìm full code ở đây : https://gist.github.com/astrung/0af1899cd73f1eaab60157fc9f14242b