[GAN] Face 생성 모델 만들기

2023. 11. 28. 16:46 · ML&DL/CV

[Face 생성 모델 만들기]

 

목표 

GAN을 이용해 Face 생성 모델 만들기

 

데이터셋

* PubFig83 데이터셋

- 유명인 83명의 얼굴 이미지 약 13,800장으로 구성 (아래 실습에서 불러온 학습 데이터는 총 13,838장)

 

https://vision.seas.harvard.edu/pubfig83/

 

PubFig83

PubFig83: A resource for studying face recognition in personal photo collections. This is a downloadable dataset of 8300 cropped facial images, made up of 100 images for each of 83 public figures. It was derived from the list of URLs compiled by Neeraj Kumar et al. …

vision.seas.harvard.edu

 

💻 실습

* 데이터 준비

- 사이트에서 첫 번째 tgz 파일(pubfig83.v1.tgz) 다운로드

 

- tgz 파일 압축 해제

import tarfile

# Path to the downloaded archive
file_path = '/content/drive/MyDrive/Computer_Vision/GAN/pubfig83.v1.tgz'

# Directory to extract the archive into
extract_path = '/content/drive/MyDrive/Computer_Vision/GAN/face_images/'

# Extract the tar.gz archive.
# The archive is downloaded from an external website, so use the 'data'
# extraction filter when this Python provides it (3.12+, also backported to
# 3.8.17 / 3.9.17 / 3.10.12 / 3.11.4). It refuses members with absolute paths
# or '..' components that would be written outside extract_path. Older
# interpreters fall back to the previous unfiltered behaviour.
with tarfile.open(file_path, 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path=extract_path, filter='data')
    else:
        tar.extractall(path=extract_path)

 

-> 압축을 풀면 pubfig83 폴더 아래에 유명인별 하위 폴더 83개(폴더명: 해당 유명인 이름)가 생성되고, 각 폴더 안에 해당 인물의 이미지가 들어 있습니다.

 

 

 

 

 

* DataLoader 만들기

import os
import numpy as np
import math
import matplotlib.pyplot as plt

import torchvision.utils as utils
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
# Every image is resized so its shorter side is image_size pixels, then
# center-cropped to image_size x image_size, converted to a tensor in [0, 1],
# and shifted to [-1, 1] per RGB channel via (x - 0.5) / 0.5.
image_size = 64
data_folder = '/content/drive/MyDrive/Computer_Vision/GAN/face_images/pubfig83'

face_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder treats each sub-folder of data_folder as one class.
dataset = datasets.ImageFolder(root=data_folder, transform=face_transform)

 

* 13838 장의 학습 데이터

 

* 학습 파라미터 정의

# Training hyper-parameters.
n_epochs = 5
batch_size = 128
lr = 0.0002            # learning rate
b1 = 0.5               # presumably the optimizer's beta1 (optimizer not shown here) — TODO confirm
b2 = 0.999             # presumably the optimizer's beta2 — TODO confirm
latent_dim = 100       # noise size; matches the 100 input channels of Generator.tconv1
channels = 3           # the faces are RGB: Normalize uses 3 means/stds and the nets use 3 image channels
sample_interval = 400  # presumably how often (in iterations) samples are saved — TODO confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 

* 데이터 살펴보기

# Shuffled mini-batches of (image, label) pairs drawn from the face dataset.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Preview up to 64 images from one batch as a grid. make_grid's normalize=True
# rescales the [-1, 1] pixels into a displayable range. The grid is CHW, while
# imshow expects HWC.
real_batch = next(iter(dataloader))
preview_images = real_batch[0].to(device)[:64]
preview_grid = utils.make_grid(preview_images, padding=2, normalize=True).cpu()

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

def weights_init(w):
    """Apply DCGAN-style initialisation to a single module.

    Intended to be used as ``model.apply(weights_init)``, which calls it on
    every submodule. The rules are:

    - Convolution and transposed-convolution weights are drawn from N(0, 0.02).
    - Batch-norm scales are drawn from N(1, 0.02), and their biases are set to 0.
    - All other modules are left untouched.

    Args:
        w: the ``nn.Module`` to initialise. The parameter name is kept for
            backward compatibility.
    """
    classname = w.__class__.__name__
    # PyTorch layer class names are CamelCase ("Conv2d", "ConvTranspose2d",
    # "BatchNorm2d"). The substring test is case-sensitive, so lowercase
    # patterns such as 'conv' or 'bn' would never match any layer.
    if 'Conv' in classname:
        nn.init.normal_(w.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm built with affine=False has no weight/bias to initialise.
        if getattr(w, 'weight', None) is not None:
            nn.init.normal_(w.weight, 1.0, 0.02)
        if getattr(w, 'bias', None) is not None:
            nn.init.constant_(w.bias, 0)
class Generator(nn.Module):
    """DCGAN generator that maps a noise tensor to a 64x64 image.

    Input shape is ``(N, latent_dim, 1, 1)`` and output shape is
    ``(N, channels, 64, 64)``. The spatial size grows 1 -> 4 -> 8 -> 16 -> 32
    -> 64 while the feature width shrinks ngf*8 -> ngf*4 -> ngf*2 -> ngf ->
    channels. The final tanh puts pixels in [-1, 1], the same range the
    dataset's ``Normalize((0.5,)*3, (0.5,)*3)`` produces for real images.

    Args:
        latent_dim: size of the input noise vector (default 100).
        ngf: base number of generator feature maps (default 64).
        channels: number of output image channels (default 3, RGB).
    """

    def __init__(self, latent_dim=100, ngf=64, channels=3):
        super().__init__()
        # (latent_dim, 1, 1) -> (ngf*8, 4, 4)
        self.tconv1 = nn.ConvTranspose2d(latent_dim, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf * 8)
        # (ngf*8, 4, 4) -> (ngf*4, 8, 8)
        self.tconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(ngf * 4)
        # (ngf*4, 8, 8) -> (ngf*2, 16, 16)
        self.tconv3 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(ngf * 2)
        # (ngf*2, 16, 16) -> (ngf, 32, 32)
        self.tconv4 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(ngf)
        # (ngf, 32, 32) -> (channels, 64, 64); no BatchNorm on the output layer
        self.tconv5 = nn.ConvTranspose2d(ngf, channels, 4, 2, 1, bias=False)

    def forward(self, x):
        x = F.relu(self.bn1(self.tconv1(x)))
        x = F.relu(self.bn2(self.tconv2(x)))
        x = F.relu(self.bn3(self.tconv3(x)))
        x = F.relu(self.bn4(self.tconv4(x)))
        return torch.tanh(self.tconv5(x))
class Discriminator(nn.Module):
    """DCGAN discriminator that scores a 64x64 image with a value in [0, 1].

    Input shape is ``(N, channels, 64, 64)`` and output shape is
    ``(N, 1, 1, 1)``, holding sigmoid outputs. Which end means "real" is set
    by the labels used in the training loop (not shown here; presumably
    1 = real — TODO confirm). Each stride-2 convolution halves the spatial
    size (64 -> 32 -> 16 -> 8 -> 4). The last 4x4 convolution has no padding
    and reduces 4x4 to 1x1.

    Args:
        channels: number of input image channels (default 3, RGB).
        ndf: base number of discriminator feature maps (default 64).
    """

    def __init__(self, channels=3, ndf=64):
        super().__init__()
        # (channels, 64, 64) -> (ndf, 32, 32); no BatchNorm on the input layer
        self.conv1 = nn.Conv2d(channels, ndf, 4, 2, 1, bias=False)

        # (ndf, 32, 32) -> (ndf*2, 16, 16)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(ndf * 2)

        # (ndf*2, 16, 16) -> (ndf*4, 8, 8)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(ndf * 4)

        # (ndf*4, 8, 8) -> (ndf*8, 4, 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(ndf * 8)

        # (ndf*8, 4, 4) -> (1, 1, 1)
        self.conv5 = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2, inplace=True)
        return torch.sigmoid(self.conv5(x))
# Construct both networks on the selected device first, then run
# weights_init over every submodule of each, then print the architectures.
# Construction and initialisation stay in separate passes so the random-number
# draws happen in the same order as before.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

for net in (generator, discriminator):
    net.apply(weights_init)

for net in (generator, discriminator):
    print(net)

 

 

 

 

 

 

 

 

 

아래는 위 설정의 n_epochs = 5보다 길게 학습했을 때 생성된 이미지입니다.

10 epoch 학습 결과
20 epoch 학습 결과

 

40 epoch 학습 결과


* epoch 수를 늘려 학습할수록 생성되는 얼굴이 점점 구체적인 형태를 갖춘다!