[GAN] CIFAR-10 데이터셋을 이용한 이미지 생성 모델 만들기
2023. 11. 28. 11:30ㆍML&DL/CV
[Computer Vision] GAN (Generative adversarial network)
[ GAN (Generative adversarial network) ] 1. GAN? * GAN은 이미지 생성에 주로 사용 하는 딥러닝 모델이다. - CNN은 데이터셋이 어떤 클래스인지 분류하는 문제에서 주로 사용한다면, GAN은 데이터셋과 유사한
situdy.tistory.com
목표
CIFAR-10 데이터셋 기반 새로운 이미지 생성 모델 만들기 (GAN)
데이터셋
* CIFAR-10
- 10가지 객체가 포함되어있는 이미지 데이터셋
💻 실습
* 라이브러리 import
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision.utils as utils
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
* 데이터셋 불러오기
- 64x64 가 되도록 리사이즈 하여 다운로드
image_size = 64
dataset = datasets.CIFAR10(root='data', download = True,
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
* 학습 하이퍼파라미터 선언
n_epochs = 5
batch_size = 128
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 100
channels = 1
sample_interval = 400
device = torch.device("cuda" if(torch.cuda.is_available()) else "cpu")
* 데이터 로더
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
* 데이터 살펴보기
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(utils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
- 생성 이미지의 기반이 될 학습 데이터.
- 학습 데이터를 변형하여 새로운 이미지를 생성함
* 가중치 초기화 함수
def weights_init(w):
classname = w.__class__.__name__
if classname.find('conv') != -1:
nn.init.normal_(w.weight.data, 0.0, 0.02)
elif classname.find('bn') != -1:
nn.init.normal_(w.weight.data, 1.0, 0.02)
nn.init.constant_(w.bias.data, 0)
* 생성자(generator) 함수
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.tconv1 = nn.ConvTranspose2d(100, 64*8, kernel_size=4, stride=1, padding=0, bias=False)
self.bn1 = nn.BatchNorm2d(64*8)
self.tconv2 = nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False)
self.bn2 = nn.BatchNorm2d(64*4)
self.tconv3 = nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False)
self.bn3 = nn.BatchNorm2d(64*2)
self.tconv4 = nn.ConvTranspose2d(64*2, 64, 4, 2, 1, bias=False)
self.bn4 = nn.BatchNorm2d(64)
self.tconv5 = nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False)
def forward(self, x):
x = F.relu(self.bn1(self.tconv1(x)))
x = F.relu(self.bn2(self.tconv2(x)))
x = F.relu(self.bn3(self.tconv3(x)))
x = F.relu(self.bn4(self.tconv4(x)))
x = F.tanh(self.tconv5(x))
return x
* 판별자(discriminator) 함수
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 4, 2, 1, bias=False) # 3 x 64 x 64
self.conv2 = nn.Conv2d(64, 64*2, 4, 2, 1, bias=False)
self.bn2 = nn.BatchNorm2d(64*2)
self.conv3 = nn.Conv2d(64*2, 64*4, 4, 2, 1, bias=False) # 64 x 2 x 16 x 16
self.bn3 = nn.BatchNorm2d(64*4)
self.conv4 = nn.Conv2d(64*4, 64*8, 4, 2, 1, bias=False)# 64 x 4 x 8 x 8
self.bn4 = nn.BatchNorm2d(64*8)
self.conv5 = nn.Conv2d(64*8, 1, 4, 1, 0, bias=False)# 64 x 4 x 4 x 4
def forward(self, x):
x = F.leaky_relu(self.conv1(x), 0.2, True)
x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, True)
x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, True)
x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, True)
x = F.sigmoid(self.conv5(x))
return x
* 모델 생성
generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)
print(generator)
print(discriminator)
* 모델 컴파일
- Binary Cross Entropy
- 생성자와 판별자 옵티마이저각각 선언 해주어야함
# Loss function
adversarial_loss = nn.BCELoss()
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
* 레이블 선언 (1- 진짜 데이터 / 0 - 가짜 데이터)
- 랜덤 노이즈 고정
fixed_noise = torch.randn(64, 100, 1, 1, device=device)
real_label = 1.
fake_label = 0.
* 모델 학습
img_list = []
G_losses = []
D_losses = []
iters = 0
for epoch in range(n_epochs):
for i, data in enumerate(dataloader, 0):
# 1. Discriminator 학습
# 1-1. Real data
real_img = data[0].to(device)
b_size = real_img.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
discriminator.zero_grad()
output = discriminator(real_img).view(-1)
real_loss = adversarial_loss(output, label)
real_loss.backward()
D_x = output.mean().item()
# 1-2. Fake data
noise = torch.randn(b_size, 100, 1, 1, device=device)
fake = generator(noise)
label.fill_(fake_label)
output = discriminator(fake.detach()).view(-1)
fake_loss = adversarial_loss(output, label)
fake_loss.backward()
D_G_z1 = output.mean().item()
disc_loss = real_loss + fake_loss
optimizer_D.step()
# 2. Generator 학습
generator.zero_grad()
label.fill_(real_label)
output = discriminator(fake).view(-1)
gen_loss = adversarial_loss(output, label)
gen_loss.backward()
D_G_z2 = output.mean().item()
optimizer_G.step()
if i % 50 == 0:
print('[{}/{}][{}/{}]'.format(epoch+1, n_epochs, i, len(dataloader)))
print('Discriminator Loss:{:.4f}\t Generator Loss:{:.4f}\t D(x):{:.4f}\t D(G(z)):{:.4f}/{:.4f}'.format(disc_loss.item(), gen_loss.item(), D_x, D_G_z1, D_G_z2))
G_losses.append(gen_loss.item())
D_losses.append(disc_loss.item())
if (iters % 500 == 0) or ((epoch == n_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = generator(fixed_noise).detach().cpu()
img_list.append(utils.make_grid(fake, padding=2, normalize=True))
iters += 1
'ML&DL > CV' 카테고리의 다른 글
[Computer Vision] CAM(Class Avtivation Map) 모델을 이용한 특징 시각화 (0) | 2023.11.29 |
---|---|
[GAN] Face 생성 모델 만들기 (0) | 2023.11.28 |
[Computer Vision] GAN (Generative adversarial network) (0) | 2023.11.28 |
[Object Detection] COCO 데이터셋을 이용한 food 객체 인식 (0) | 2023.11.27 |
[Object Detection] COCO 데이터셋을 이용한 교통수단 객체 인식 (0) | 2023.11.24 |