[GAN] Face 생성 모델 만들기
2023. 11. 28. 16:46ㆍML&DL/CV
[Face 생성 모델 만들기]
목표
GAN을 이용해 Face 생성 모델 만들기
데이터셋
* PubFig83 데이터셋
- 유명인 83명에 대한 10000개 가량의 이미지로 구성
https://vision.seas.harvard.edu/pubfig83/
PubFig83
PubFig83: A resource for studying face recognition in personal photo collections This is a downloadable dataset of 8300 cropped facial images, made up of 100 images for each of 83 public figures. It was derived from the list of URLs compiled by Neeraj K
vision.seas.harvard.edu
💻 실습
* 데이터 준비
- 사이트에서 첫번째 tgz 파일 다운로드
- tgz 파일 압축 해제
import tarfile
# 압축 파일 경로
file_path = '/content/drive/MyDrive/Computer_Vision/GAN/pubfig83.v1.tgz'
# 압축 해제할 디렉토리 경로
extract_path = '/content/drive/MyDrive/Computer_Vision/GAN/face_images/'
# tar 파일 압축 해제
with tarfile.open(file_path, 'r:gz') as tar:
tar.extractall(path=extract_path)
-> 84개의 폴더(폴더명: 해당 유명인 이름) 안에 압축이 풀어졌습니다.
* DataLoader 만들기
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision.utils as utils
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
image_size = 64
data_folder = '/content/drive/MyDrive/Computer_Vision/GAN/face_images/pubfig83'
dataset = datasets.ImageFolder(root=data_folder,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
* 13838 장의 학습 데이터
* 학습 파라미터 정의
n_epochs = 5
batch_size = 128
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 100
channels = 1
sample_interval = 400
device = torch.device("cuda" if(torch.cuda.is_available()) else "cpu")
* 데이터 살펴보기
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(utils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
def weights_init(w):
classname = w.__class__.__name__
if classname.find('conv') != -1:
nn.init.normal_(w.weight.data, 0.0, 0.02)
elif classname.find('bn') != -1:
nn.init.normal_(w.weight.data, 1.0, 0.02)
nn.init.constant_(w.bias.data, 0)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.tconv1 = nn.ConvTranspose2d(100, 64*8, kernel_size=4, stride=1, padding=0, bias=False)
self.bn1 = nn.BatchNorm2d(64*8)
self.tconv2 = nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False)
self.bn2 = nn.BatchNorm2d(64*4)
self.tconv3 = nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False)
self.bn3 = nn.BatchNorm2d(64*2)
self.tconv4 = nn.ConvTranspose2d(64*2, 64, 4, 2, 1, bias=False)
self.bn4 = nn.BatchNorm2d(64)
self.tconv5 = nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False)
def forward(self, x):
x = F.relu(self.bn1(self.tconv1(x)))
x = F.relu(self.bn2(self.tconv2(x)))
x = F.relu(self.bn3(self.tconv3(x)))
x = F.relu(self.bn4(self.tconv4(x)))
x = F.tanh(self.tconv5(x))
return x
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 4, 2, 1, bias=False) # 3 x 64 x 64
self.conv2 = nn.Conv2d(64, 64*2, 4, 2, 1, bias=False)
self.bn2 = nn.BatchNorm2d(64*2)
self.conv3 = nn.Conv2d(64*2, 64*4, 4, 2, 1, bias=False) # 64 x 2 x 16 x 16
self.bn3 = nn.BatchNorm2d(64*4)
self.conv4 = nn.Conv2d(64*4, 64*8, 4, 2, 1, bias=False)# 64 x 4 x 8 x 8
self.bn4 = nn.BatchNorm2d(64*8)
self.conv5 = nn.Conv2d(64*8, 1, 4, 1, 0, bias=False)# 64 x 4 x 4 x 4
def forward(self, x):
x = F.leaky_relu(self.conv1(x), 0.2, True)
x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, True)
x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, True)
x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, True)
x = F.sigmoid(self.conv5(x))
return x
generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)
print(generator)
print(discriminator)
* epoch을 늘려서 학습할수록 점점 구체적인 형태로 만들어진다!
'ML&DL > CV' 카테고리의 다른 글
[Computer Vision] CAM(Class Avtivation Map) 모델을 이용한 특징 시각화 (0) | 2023.11.29 |
---|---|
[GAN] CIFAR-10 데이터셋을 이용한 이미지 생성 모델 만들기 (0) | 2023.11.28 |
[Computer Vision] GAN (Generative adversarial network) (0) | 2023.11.28 |
[Object Detection] COCO 데이터셋을 이용한 food 객체 인식 (0) | 2023.11.27 |
[Object Detection] COCO 데이터셋을 이용한 교통수단 객체 인식 (0) | 2023.11.24 |