[Binary Classification] Brain Tumor Negative/Positive Model Using an MRI Dataset

2023. 11. 22. 16:36 | ML&DL/CV

[Binary classification model for brain tumor detection]

 

 

Dataset

https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset

 


 

- Uses a brain MRI dataset with four classes: glioma, meningioma, pituitary tumor, and no tumor (notumor)

- For this problem, the three tumor classes (glioma + meningioma + pituitary) are merged into a single positive (tumor) class, and the no-tumor class becomes the negative class
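
In other words, the class-to-label mapping is as follows (an illustrative snippet; the hands-on section below implements it by copying the image files of the three tumor classes into a merged tumor folder):

# Illustrative mapping from the original four classes to the binary target.
label_map = {
    'glioma': 'tumor',      # positive
    'meningioma': 'tumor',  # positive
    'pituitary': 'tumor',   # positive
    'notumor': 'notumor',   # negative
}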

 

Dataset structure

 tumor
  ├─Training
  │   ├─glioma     : 1321 files
  │   ├─meningioma : 1339 files
  │   ├─notumor    : 1595 files
  │   └─pituitary  : 1457 files
  └─Testing
      ├─glioma     : 300 files
      ├─meningioma : 306 files
      ├─notumor    : 405 files
      └─pituitary  : 300 files

 

 

💻 Hands-on

* Data preparation

- Turn the four classes ['glioma', 'meningioma', 'notumor', 'pituitary'] into the two positive/negative classes ['tumor', 'notumor']

import os
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import shutil


dir_main = "/content/gdrive/MyDrive/tumor/"
classes_before = ['glioma', 'meningioma', 'notumor', 'pituitary']
classes = ['tumor', 'notumor']
source = "Training"

def create_symlink(img_path, from_cls, to_cls):
    # Despite the name, the symlink is commented out and the file is
    # copied into the corresponding 'tumor' folder instead.
    src = os.path.abspath(img_path)
    dst = src.replace(from_cls, to_cls)
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    if not os.path.exists(dst):
        # os.symlink(src, dst)
        shutil.copy(src, dst)

# Merge the three tumor classes into a single 'tumor' class.
# The */*/*.jpg pattern covers both the Training and Testing folders.
for img_path in glob(os.path.join(dir_main, "*/*/*.jpg")):
    for from_cls in ['glioma', 'meningioma', 'pituitary']:
        if from_cls in img_path:
            create_symlink(img_path, from_cls, 'tumor')

# Collect image paths of the two merged classes from the Training split.
image_paths = []
for cls in classes:
    image_paths.extend(glob(os.path.join(dir_main, source, f"{cls}/*.jpg")))

# Show a few sample images per class.
cls_image_paths = {}
n_show = 5
for cls in classes:
    cls_image_paths[cls] = [image_path for image_path in image_paths if cls == image_path.split("/")[-2]][:n_show]

for cls in classes:
    fig, axes = plt.subplots(nrows=1, ncols=n_show, figsize=(10, 2))
    for idx, image_path in enumerate(cls_image_paths[cls]):
        img = Image.open(image_path)
        axes[idx].set_title(f"{cls}_{idx}")
        axes[idx].imshow(img)

* Train / val / test split

import os
import numpy as np
from glob import glob
np.random.seed(724)

dir_main = "/content/gdrive/MyDrive/tumor/"
classes = ['tumor', 'notumor']

def preprocessing(dir_name, classes):
    # Collect image paths for the given split and derive labels from the folder name.
    x_data = []
    for cls in classes:
        dir_data = os.path.join(dir_main, dir_name)
        x_data.extend(glob(f"{dir_data}/{cls}/*.jpg"))
    y_data = np.array([x.split("/")[-2] for x in x_data])
    return x_data, y_data

x_train, y_train = preprocessing("Training", classes)
x_test, y_test = preprocessing("Testing", classes)
# Note: the validation set reuses the Testing split; there is no separate hold-out set here.
x_val, y_val = preprocessing("Testing", classes)

def get_numbers(ys, cls=None):
    # Count samples per class; return the whole dict or the count of one class.
    cls_cnt = {}
    for y in ys:
        if y not in cls_cnt.keys():
            cls_cnt[y] = 0
        cls_cnt[y] += 1
    if cls is None:
        return cls_cnt
    return cls_cnt[cls]

print("Class\t\tTrain\tVal\tTest\n")
for cls in classes:
    print(f"{cls:10}\t{get_numbers(y_train, cls)}\t{get_numbers(y_val, cls)}\t{get_numbers(y_test, cls)}")
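
Note that the validation and test sets above are the same Testing split. If a separate validation set is preferred, one option (not part of the original post) is to carve it out of the training file list, for example with scikit-learn:

# Hypothetical alternative: hold out 10% of the Training files for validation.
from sklearn.model_selection import train_test_split

x_train_part, x_val_part, y_train_part, y_val_part = train_test_split(
    x_train, y_train,
    test_size=0.1,       # 10% of Training becomes the validation set
    stratify=y_train,    # preserve the tumor/notumor ratio in both splits
    random_state=724,
)
print(len(x_train_part), len(x_val_part))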

 

* Data loader

 

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from glob import glob
from PIL import Image

class TumorDataset(Dataset):
    def __init__(self, dir_dataset, tr):
        self.dir_dataset = os.path.abspath(dir_dataset)
        self.classes = ['tumor', 'notumor']
        # Gather all image paths for the two classes.
        self.filelist = []
        for cls in self.classes:
            self.filelist.extend(glob(self.dir_dataset + f'/{cls}/*.jpg'))
        assert len(self.filelist) != 0, f"{self.dir_dataset + '/<cls>/*.jpg'} is empty"
        self.tr = tr

    def get_image(self, filename):
        # Load the image as RGB and apply the given transforms.
        img = Image.open(filename).convert("RGB")
        img = self.tr(img)
        return img

    def get_label(self, filename):
        # One-hot label derived from the parent folder name.
        label = np.array([0] * len(self.classes))
        cls = filename.split('/')[-2]
        label[self.classes.index(cls)] = 1
        return torch.from_numpy(label).type(torch.FloatTensor)

    def __getitem__(self, idx):
        filename = self.filelist[idx]
        img = self.get_image(filename)
        label = self.get_label(filename)
        return img, label

    def __len__(self):
        return len(self.filelist)

* Data preprocessing

from torch.utils.data import DataLoader
import torchvision.transforms as T

# ImageNet statistics used by the timm pretrained weights
# (see model.default_cfg: mean (0.485, 0.456, 0.406), std (0.229, 0.224, 0.225)).
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Training transforms: light augmentation (random crop + horizontal flip).
train_tr = T.Compose([
    T.Resize((256, 256)),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    normalize
])

# Validation/test transforms: deterministic resize only.
test_tr = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    normalize
])

train_ds = TumorDataset(os.path.join(dir_main, "Training"), train_tr)
val_ds = TumorDataset(os.path.join(dir_main, "Testing"), test_tr)
test_ds = TumorDataset(os.path.join(dir_main, "Testing"), test_tr)

train_dl = DataLoader(train_ds, shuffle=True, num_workers=0, batch_size=64)
val_dl = DataLoader(val_ds, shuffle=False, num_workers=0, batch_size=64)   # no need to shuffle for evaluation
test_dl = DataLoader(test_ds, shuffle=False, num_workers=0, batch_size=64)
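
A quick sanity check on the loaders (a minimal sketch, assuming the cells above ran): one training batch should contain 64 images of shape 3×224×224 together with one-hot labels of length 2.

# Inspect one batch from the training loader.
images, labels = next(iter(train_dl))
print(images.shape)  # expected: torch.Size([64, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([64, 2])
print(labels[0])     # one-hot label, e.g. tensor([1., 0.]) for 'tumor'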

 

* Model

- Uses the pre-trained ResNet-18 model provided by timm

 

 

!pip install timm
import timm
import torch

# timm ships a large zoo of pretrained models.
print(f"The number of pretrained models : {len(timm.list_models('*', pretrained=True))}")
timm.list_models('resnet*', pretrained=True)

# Inspect the default config (input size, normalization stats, etc.).
model = timm.create_model('resnet18', pretrained=True)
model.default_cfg

# Recreate the model with a 2-class head and check the output shape.
model = timm.create_model('resnet18', pretrained=True, num_classes=len(classes), global_pool='avg')
model.eval()
model(torch.randn(1, 3, 224, 224)).shape

 

import timm
import torch
import torch.nn.functional as F
from torch import nn

class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained ResNet-18 backbone with a 2-class head (tumor / notumor).
        self.model = timm.create_model('resnet18', pretrained=True, num_classes=len(classes), global_pool='avg')

    def forward(self, x):
        # Sigmoid outputs so binary cross-entropy can be applied directly.
        return torch.sigmoid(self.model(x))

model = ResNet18()
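
One detail worth noting: the head has two outputs and a sigmoid is applied to each one independently, so the two probabilities do not have to sum to 1 (unlike softmax); F.binary_cross_entropy then compares them with the one-hot labels per class. A quick illustrative check:

# Each output is an independent sigmoid probability, so a row need not sum to 1.
model.eval()
with torch.no_grad():
    probs = model(torch.randn(2, 3, 224, 224))
print(probs)             # two values per sample, each in [0, 1]
print(probs.sum(dim=1))  # generally not exactly 1.0 (unlike softmax)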

 

* Training setup

import pickle
import os
import torch
import torch.nn.functional as F

class TrainHelper():
    def __init__(self, save_path='./ckpt/history.pickle', history=None):
        # Avoid a mutable default argument for the history list.
        self.history = [] if history is None else history
        self.save_path = save_path
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    def accuracy(self, outputs, labels):
        # Compare the argmax of predictions with the argmax of the one-hot labels.
        pred = torch.max(outputs, dim=1)[1]
        gt = torch.max(labels, dim=1)[1]
        return torch.tensor(torch.sum(pred == gt).item() / len(pred))

    @torch.no_grad()
    def validation(self, model, batch):
        images, labels = batch
        out = model(images)
        acc = self.accuracy(out, labels)
        loss = F.binary_cross_entropy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    @torch.no_grad()
    def evaluation(self, model, data_loader):
        model.eval()
        outputs = [self.validation(model, batch) for batch in data_loader]
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': round(epoch_loss.item(), 5), 'val_acc': round(epoch_acc.item(), 5)}

    def logging(self, epoch, result):
        # Print the epoch summary and persist the history to disk.
        print("Epoch {}: train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
        self.history.append(result)
        with open(self.save_path, 'wb') as f:
            pickle.dump(self.history, f)

train_helper = TrainHelper()
# Baseline evaluation before fine-tuning.
train_helper.evaluation(model, val_dl)

 

* Model training

- 3 epochs

- Save the best model (by validation accuracy)

from tqdm import tqdm

epochs = 3
optimizer = torch.optim.Adam(model.parameters(), 5.5e-5)
val_acc_best = 0
save_model_path = "./ckpt/"
os.makedirs(save_model_path, exist_ok=True)

for epoch in range(epochs):
    # Training phase
    model.train()
    train_losses = []
    for batch in tqdm(train_dl):
        inputs, targets = batch
        outputs = model(inputs)
        loss = F.binary_cross_entropy(outputs, targets)

        train_losses.append(loss.detach())  # detach so the computation graph is not kept
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Validation phase
    result = train_helper.evaluation(model, val_dl)
    result['train_loss'] = torch.stack(train_losses).mean().item()

    # Save the best model (delete the previous best checkpoint first)
    if result['val_acc'] >= val_acc_best:
        val_acc_best = result['val_acc']
        if 'save_model_name' in locals() and os.path.exists(save_model_name):
            os.remove(save_model_name)
        save_model_name = os.path.join(save_model_path, f"best_ep_{epoch}_{val_acc_best}.pt")
        torch.save(model.state_dict(), save_model_name)
        print(f"Saved PyTorch Model State to {save_model_name}")

    train_helper.logging(epoch, result)

# Save the last model as well
save_model_name = os.path.join(save_model_path, f"last_ep_{epoch}_{val_acc_best}.pt")
torch.save(model.state_dict(), save_model_name)
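
To evaluate or resume from a checkpoint later, the saved weights can be loaded back into a freshly constructed model. A minimal sketch, assuming you use the checkpoint path printed during training (e.g. the best_ep_... file) instead of the last-model path below:

# Rebuild the architecture and load the saved weights.
model = ResNet18()
model.load_state_dict(torch.load(save_model_name, map_location="cpu"))
model.eval()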

 

* Training progress visualization

- Loss

import matplotlib.pyplot as plt

train_loss = [history['train_loss'] for history in train_helper.history]
val_loss = [history['val_loss'] for history in train_helper.history]

plt.plot(train_loss, label='Train loss')
plt.plot(val_loss, label='Validation loss')
plt.grid()
plt.legend(frameon=False)

- Accuracy

import matplotlib.pyplot as plt

val_acc = [history['val_acc'] for history in train_helper.history]
plt.plot(val_acc, label='Accuracy')
plt.grid()
plt.legend(frameon=False)

 

* Training results

 

 

 

* Model test

- Incorrect cases
- Correct cases
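
The screenshots of the incorrect and correct cases are not reproduced here. As a minimal sketch (assuming the trained model, classes, and test_dl defined above), misclassified and correctly classified test images could be collected and displayed like this; show_cases is a hypothetical helper:

import matplotlib.pyplot as plt
import torch

model.eval()
wrong, correct = [], []
with torch.no_grad():
    for images, labels in test_dl:
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        gts = labels.argmax(dim=1)
        for img, p, g in zip(images, preds, gts):
            (correct if p == g else wrong).append((img, p.item(), g.item()))

def show_cases(cases, title, n=5):
    # Undo the ImageNet normalization for display and show the first n cases.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    fig, axes = plt.subplots(1, n, figsize=(12, 3))
    for ax, (img, p, g) in zip(axes, cases[:n]):
        ax.imshow((img * std + mean).permute(1, 2, 0).clamp(0, 1))
        ax.set_title(f"pred={classes[p]} / gt={classes[g]}")
        ax.axis("off")
    fig.suptitle(title)

show_cases(wrong, "Incorrect cases")
show_cases(correct, "Correct cases")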