2023. 11. 22. 16:36ㆍML&DL/CV
[뇌종양 판단 이진 분류 모델]

데이터셋
https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
Brain Tumor MRI Dataset
A dataset for classifying brain tumors
www.kaggle.com
- 신경교종(glioma) / 뇌수막종(meningioma) / 뇌하수체종양(pituitary) / 종양 없음(no tumor) 4가지 클래스로 이루어진 뇌 MRI 데이터셋을 이용
- 본 문제에서는 3개의 종양 클래스를 묶어 (신경교종 + 뇌수막종 + 뇌하수체종양)=> 양성 뇌종양
종양 없음=> 음성 뇌종양
데이터셋 구조
tumor
├─Training
│ ├─glioma : 1321 files
│ ├─meningioma : 1339 files
│ ├─notumor : 1595 files
│ └─pituitary : 1457 files
└─Testing
├─glioma : 300 files
├─meningioma : 306 files
├─notumor : 405 files
└─pituitary : 300 files
💻 실습
* 데이터 준비
- ['glioma', 'meningioma', 'notumor', 'pituitary'] 4개 클래스를 ['tumor', 'notumor'] 양/음성 클래스로 만듦
import os
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import shutil
dir_main = "/content/gdrive/MyDrive/tumor/"
classes_before = ['glioma', 'meningioma', 'notumor', 'pituitary']
classes = ['tumor', 'notumor']
source = "Training"
def create_symlink(img_path, from_cls, to_cls):
src = os.path.abspath(img_path)
dst = src.replace(from_cls, to_cls)
os.makedirs(os.path.dirname(dst), exist_ok=True)
if not os.path.exists(dst):
#os.symlink(src, dst)
shutil.copy(src, dst)
for img_path in glob(os.path.join(dir_main, f"*/*/*.jpg")):
for from_cls in ['glioma', 'meningioma', 'pituitary']:
if from_cls in img_path:
create_symlink(img_path, from_cls, 'tumor')
image_paths = []
for cls in classes:
image_paths.extend(glob(os.path.join(dir_main, source, f"{cls}/*.jpg")))
cls_image_paths = {}
n_show = 5
for cls in classes:
cls_image_paths[cls] = [image_path for image_path in image_paths if cls == image_path.split("/")[-2]][:n_show]
for cls in classes:
fig, axes = plt.subplots(nrows=1, ncols=n_show, figsize=(10,2))
for idx, image_path in enumerate(cls_image_paths[cls]):
img = Image.open(image_path)
axes[idx].set_title(f"{cls}_{idx}")
axes[idx].imshow(img)

* train/ val /test
import os
import numpy as np
from glob import glob
np.random.seed(724)
dir_main = "/content/gdrive/MyDrive/tumor/"
classes = ['tumor', 'notumor']
def preprocessing(dir_name, classes):
x_data = []
for cls in classes:
dir_data = os.path.join(dir_main, dir_name)
x_data.extend(glob(f"{dir_data}/{cls}/*.jpg"))
y_data = np.array([x.split("/")[-2] for x in x_data])
return x_data, y_data
x_train, y_train = preprocessing("Training", classes)
x_test, y_test = preprocessing("Testing", classes)
x_val, y_val = preprocessing("Testing", classes)
def get_numbers(ys, cls=None):
cls_cnt = {}
for y in ys:
if y not in cls_cnt.keys():
cls_cnt[y]=0
cls_cnt[y]+=1
if cls is None:
return cls_cnt
return cls_cnt[cls]
print(f"Class\t\tTrain\tVal\tTest\n")
for cls in classes:
print(f"{cls:10}\t{get_numbers(y_train, cls)}\t{get_numbers(y_val, cls)}\t{get_numbers(y_test, cls)}")

* 데이터로더
import torch
from torch.utils.data import Dataset
from glob import glob
class TumorDataset(Dataset):
def __init__(self, dir_dataset, tr):
self.dir_dataset = os.path.abspath(dir_dataset)
self.classes = ['tumor', 'notumor']
self.filelist = []
for cls in self.classes:
self.filelist.extend(glob(self.dir_dataset + f'/{cls}/*.jpg'))
assert len(self.filelist)!=0, f"{self.dir_dataset + '/cls/*.jpg'} is empty"
self.tr = tr
def get_image(self, filename):
img = Image.open(filename).convert("RGB")
img = self.tr(img)
return img
def get_label(self, filename):
label = np.array([0] * len(self.classes))
cls = filename.split('/')[-2]
label[self.classes.index(cls)] = 1
return torch.from_numpy(label).type(torch.FloatTensor)
def __getitem__(self, idx):
filename = self.filelist[idx]
img = self.get_image(filename)
label = self.get_label(filename)
return img, label
def __len__(self):
return len(self.filelist)
* 데이터 전처리
from torch.utils.data import DataLoader
import torchvision.transforms as T
# print(model.default_cfg['mean']) # 'mean': (0.485, 0.456, 0.406)
# print(model.default_cfg['std']) # 'std': (0.229, 0.224, 0.225)
normalize = T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
train_tr = T.Compose([
T.Resize((256, 256)),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
test_tr = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
normalize
])
train_ds = TumorDataset(os.path.join(dir_main, "Training"), train_tr)
val_ds = TumorDataset(os.path.join(dir_main, "Testing"), test_tr)
test_ds = TumorDataset(os.path.join(dir_main, "Testing"), test_tr)
train_dl = DataLoader(train_ds, shuffle=True, num_workers=0, batch_size=64)
val_dl = DataLoader(val_ds, shuffle=True, num_workers=0, batch_size=64)
test_dl = DataLoader(test_ds, shuffle=True, num_workers=0, batch_size=64)
* 모델
- Timm에서 제공하는 Pre-trained 모델 ResNet-18 사용
!pip install timm
import timm
print(f"The number of pretrained models : {len(timm.list_models('*', pretrained=True))}")
timm.list_models('resnet*', pretrained=True)
model = timm.create_model('resnet18', pretrained=True)
model.default_cfg
model = timm.create_model('resnet18', pretrained=True, num_classes=len(classes), global_pool='avg')
model.eval()
model(torch.randn(1, 3, 224, 224)).shape
import timm
import torch.nn.functional as F
from torch import nn
class ResNet18(nn.Module):
def __init__(self):
super().__init__()
self.model = timm.create_model('resnet18', pretrained=True, num_classes=len(classes), global_pool='avg')
def forward(self, x):
return torch.sigmoid(self.model(x))
model = ResNet18()
* 모델 학습 준비
import pickle
import os
class TrainHelper():
def __init__(self, save_path='./ckpt/history.pickle', history=[]):
self.history = history
self.save_path = save_path
os.makedirs(os.path.dirname(save_path), exist_ok=True)
def accuracy(self, outputs, labels):
pred = torch.max(outputs, dim=1)[1]
gt = torch.max(labels, dim=1)[1]
return torch.tensor(torch.sum(pred == gt).item() / len(pred))
@torch.no_grad()
def validation(self, batch):
images, labels = batch
out = model(images)
acc = self.accuracy(out, labels)
loss = F.binary_cross_entropy(out, labels)
return {'val_loss': loss.detach(), 'val_acc': acc}
@torch.no_grad()
def evaluation(self, model, data_loader):
model.eval()
outputs = [self.validation(batch) for batch in data_loader]
batch_losses = [x['val_loss'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean()
batch_accs = [x['val_acc'] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean()
return {'val_loss': round(epoch_loss.item(), 5), 'val_acc': round(epoch_acc.item(), 5)}
def logging(self, epoch, result):
print("Epoch {}: train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
epoch, result['train_loss'], result['val_loss'], result['val_acc']))
self.history.append(result)
with open(self.save_path, 'wb') as f:
pickle.dump(self.history, f)
train_helper = TrainHelper()
train_helper.evaluation(model, val_dl)
* 모델 학습
- 에포크 3번
- best 모델 저장
from tqdm import tqdm
epochs = 3
optimizer = torch.optim.Adam(model.parameters(), 5.5e-5)
val_acc_best = 0
save_model_path = "./ckpt/"
os.makedirs(save_model_path, exist_ok=True)
for epoch in range(epochs):
# Training Phase
model.train()
train_losses = []
for batch in tqdm(train_dl):
inputs, targets = batch
outputs = model(inputs)
loss = F.binary_cross_entropy(outputs, targets)
train_losses.append(loss)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Validation phase
result = train_helper.evaluation(model, val_dl)
result['train_loss'] = torch.stack(train_losses).mean().item()
# Save the best model
if result['val_acc'] >= val_acc_best:
val_acc_best = result['val_acc']
if 'save_model_name' in locals() and os.path.exists(save_model_name):
os.remove(save_model_name)
save_model_name = os.path.join(save_model_path, f"best_ep_{epoch}_{val_acc_best}.pt")
torch.save(model.state_dict(), save_model_name)
print(f"Saved PyTorch Model State to {save_model_name}")
train_helper.logging(epoch, result)
# Save the last model
save_model_name = os.path.join(save_model_path, f"last_ep_{epoch}_{val_acc_best}.pt")
torch.save(model.state_dict(), save_model_name)
*학습 현황 시각화
-loss
import matplotlib.pyplot as plt
train_loss = [history['train_loss'] for history in train_helper.history]
val_loss = [history['val_loss'] for history in train_helper.history]
plt.plot(train_loss, label='train loss')
plt.plot(val_loss, label='Validation loss')
plt.grid()
plt.legend(frameon=False)

- acc
import matplotlib.pyplot as plt
val_acc = [history['val_acc'] for history in train_helper.history]
plt.plot(val_acc, label='Accuracy')
plt.grid()
plt.legend(frameon=False)

* 학습 결과

* 모델 테스트


'ML&DL > CV' 카테고리의 다른 글
[Object Detection] COCO 데이터셋을 이용한 교통수단 객체 인식 (0) | 2023.11.24 |
---|---|
[Computer Vision] COCO 데이터셋 (2) | 2023.11.24 |
[Object Detection] Yolo를 이용한 실내 공간의 객체 검출 (0) | 2023.11.21 |
[Computer Vision] Object Detection (객체 검출) (1) | 2023.11.21 |
[Multi-Class Classification] 재활용 쓰레기 분리수거 모델 (1) | 2023.11.21 |