Trong bài viết này chúng ta sẽ tiếp tục với seri về GAN ( các bạn chưa biết GAN là gì có thể đọc bài viết trước ở đây ).
Bài viết gồm có 3 phần:
1.Nhắc lại ý tưởng về GAN
2. Ý tưởng cơ bản về Conditional GAN ( cGAN )
3. Training cGAN bằng tensorflow
Full code các bạn có thể tìm ở đây: https://gist.github.com/astrung/76cdb95589b181bd007d86a2b7a24c8a
1. Nhắc lại về GAN
Đầu tiên, GAN là 1 mô hình neural network gồm 2 mạng neural network nhỏ ,mô phỏng quá trình cạnh tranh để nhằm mục đích cải tiến mô hình. Vậy nên khi nói về mô hình GAN và train GAN, thục chất chúng ta đang nói tới 2 mô hình neural network:
- Generator neural network: sinh ra những dữ liệu
giống thực tế nhất có thể từ 1 nguồn noise nào đó ( thông thường là 1 phân bố xác suất ), có nhiệm vụ đánh lừa Discriminator rằng những dữ liệu mình sinh ra là thật ( dữ liệu gần với thực tế nhất có thể ). Vậy input của Generator là 1 tập các giá trị noise, và output là dữ liệu thực tế. Trong bài viết này, chúng ta sẽ kí hiệu noise là z, và dữ liệu được sinh ra là X - Discriminator neural network: phát hiện đó là dữ liệu thực tế hay dữ liệu giả do Generator sinh ra. Vậy input của Discriminator là dữ liệu, và output sẽ là 1 giá trị 0 hay 1 - giả hay thật.
Vậy 1 mô hình GAN thường được sử dụng để tạo ra các dữ liệu giống như trong thực tế. Ví dụ như chúng ta có 1 tập ảnh các chữ số đen trắng ( MNIST dataset ), và giờ chúng ta muốn sinh ra các ảnh chữ số đen trắng mới thì nhiệm vụ của Generator sẽ là sinh ra các ảnh đó, còn Discriminator sẽ có nhiệm vụ phân biệt giữa những bức ảnh được Generator sinh ra và những bức ảnh từ thực tế, và sau đó phản hồi sự khác nhau để Discriminator có thể chỉnh sửa những bức ảnh mình sinh ra.Thông qua sự cạnh tranh giữa Discriminator và Generator, các bức ảnh được sinh ra sẽ ngày càng tốt hơn.
Nếu các bạn muốn biết chi tiết hơn thì có thể đọc bài viết trước ở đây
2. Ý tưởng cơ bản về Conditional GAN ( cGAN )
Vậy trong ví dụ MNIST trên, output của 1 generator sẽ là 1 bức ảnh đen trắng của các số từ 0 đến 9. Nhưng bức ảnh của các con số này mang tính hoàn toàn ngẫu nhiên, chúng ta không thể nào biết trước được bức ảnh sinh ra sẽ là số 0 hay số 9, hoặc là 1 số nào đó khác.
Vậy làm sao để chúng ta sinh ra chỉ một số cố định theo ý muốn của chúng ta. Lúc này chúng ta có 2 lựa chọn:
-
Sử dụng 10 mô hình GAN riêng biệt tương ứng cho mỗi số từ 0 đến 9.Mỗi 1 mô hình GAN sẽ chỉ sử dụng duy nhất 1 tập ảnh của 1 số để train, và sau khi train sẽ chỉ sinh ra duy nhất ảnh của số đó. Ví dụ như chúng ta sẽ train cho riêng 1 GAN với các bức ảnh số 0, và sau đó mô hình GAN này sẽ chỉ sinh ra duy nhất ảnh số 0. Tất nhiên cách này sẽ rất tốn kém hiệu năng rồi.
-
Chúng ta chỉ sử dụng 1 mô hình GAN, nhưng sử dụng thêm 1 input mới cho Generator và Discriminator - chúng ta gọi input này là y để giúp Generator và Discriminator phân biệt các điều kiện khác nhau ( giá trị input mới ) thì sẽ sinh ra các bức ảnh của các chữ số khác nhau. Đây chính là ý tưởng cơ bản của Conditional GAN.
Ứng dụng Conditional GAN cho mô hình MNIST của chúng ta, giá trị của y sẽ là từ 0 đến 9, mỗi giá trị của y tương ứng với 1 số mà chúng ta muốn sinh ra.
Generator sẽ dựa vào giá trị của y để sinh ra các bức ảnh chữ số tương ứng. Ví dụ chúng ta truyền vào 0 thì Generator sẽ chỉ sinh ra các bức ảnh của số 0.
Vậy nếu chẳng may sinh ra ảnh của số 1 thì sao? Lúc này là nhiệm vụ của Discriminator. Discriminator giờ đây không chỉ phân biệt thật hay giả, mà còn có nhiệm vụ kiểm tra xem liệu bức ảnh được Generator sinh ra có đúng với điều kiện được truyền vào là y hay không. Nói cách khác, nếu như chúng ta truyền vào 0 mà Generator lại cho ra các bức ảnh của số khác thì Discriminator sẽ coi đây như là ảnh giả - ouput của Discriminator phải là 0. Chỉ khi bức ảnh Generator sinh ra tương ứng với y - truyền vào y = 0 thì nhận được ảnh số 0 thì Discriminator mới coi đây là ảnh thật - ouput của Discriminator là 1.
Vậy này Input của Generator sẽ là noise + y
- output là ảnh (X), còn input của Discriminator sẽ là X + y
, output vẫn sẽ là 1 biến 0 hay 1. Trong trường hợp y là số từ 0 đến 9 , chúng ta sẽ vector hóa nó bằng one hot encoding.
Như vậy các bạn có thể hiểu đơn giản là cGAN chính là GAN được gắn thêm 1 input để phân biệ các điều kiện khác nhau. Các để thêm input đơn giản nhất chính là đầu tiên chúng ta vector hóa đầu vào mới, rồi chèn vào sau đầu vào cũ thôi:
3. Training cGAN bằng tensorflow:
Trong ví dụ này, chúng ta sẽ xây dựng CGAN cho mô hình MNIST. Khi truyền vào y 1 giá trị từ 0 đến 9, chúng ta sẽ nhận được các bức ảnh tương ứng với giá trị của y
Vậy đầu tiên, chúng ta vector hóa giá trị y từ 0 đến 9 sang one-hot encode - 1 vector gồm 10 phần tử:
y = tf.placeholder(tf.float32, shape=[None, 10])
Tiếp theo chúng ta cần phải xây dựng Generator Và Discriminator. Ở bài trước chúng ta đã sử dụng Convolution , nhưng trong bài này để đơn giản hóa, chúng ta sẽ chỉ sử dụng 1 mô hình feedforward gồm 2 lớp layer. Lớp đầu sẽ sử dụng activation function là RELU, lớp 2 sẽ sử dụng acvation là Sigmoid.
def generator(z, y):
# Concatenate z and y, z is noise and y is one-hot vector
inputs = tf.concat(concat_dim=1, values=[z, y])
G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
def discriminator(x, y):
# Concatenate x and y, x is images, and y is one-hot vector
inputs = tf.concat(concat_dim=1, values=[x, y])
D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
Như các bạn thấy trên code, khi có thêm 1 đầu với mới là y thì chúng ta thực tế là concat
( ghép ) vào với đầu vào cũ của Generator hay Discriminator, rồi tiếp tục coi đó như là Input của mạng neural network.
Sau đó là quá trình training:
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
Cuối cùng chúng ta sẽ thứ sinh kết quả mới. Chúng ta sẽ truyền vào 1 giá trị là y= 3
, và chúng ta muốn Generator chỉ sinh ra các bức ảnh của số 3. Vậy đầu tiên chúng ta cần chuyển số 3 của chúng ta về one-hot encoding trước đã:
n_sample = 16 # số lượng ảnh muốn sinh
Z_sample = sample_Z(n_sample, Z_dim) # lấy giá trị noise
y_sample = np.zeros(shape=[n_sample, 10]) # y có thể từ 0 đến 9, vậy one hot có 10 phần tử
y_sample[:, 3] = 1 # y= 3 vậy thì trong one hot vector phần tử ở index số 3 bằng 1, các phần tử khác bằng 0
Sau đó chúng ta sẽ feed vào trong Generator để sinh ra ảnh mới:
samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})
fig = plt.figure(figsize=(4, 4))
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
plt.show()
Kết quả là đây. Các ảnh số 3 rất đẹp:
Các bạn có thể check code trực tiếp ở đây: https://gist.github.com/astrung/76cdb95589b181bd007d86a2b7a24c8a