
Developing and training the AlexNet model using PyTorch on the CIFAR-10 dataset

개발공주 2023. 4. 10. 01:10
PyTorch model code based on the paper "ImageNet Classification with Deep Convolutional Neural Networks" (AlexNet), adapted to CIFAR-10: inputs are resized to 227x227 and the final layer outputs 10 classes.

 

1. Library import

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

 

2. AlexNet Network 

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        output = self.classifier(x)
        return output
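
For reference, the 256 * 6 * 6 flatten size assumes 227x227 inputs (see the Resize transform in main() below): 227 -> conv1 55 -> pool 27 -> conv2 27 -> pool 13 -> conv3~5 13 -> pool 6. A minimal shape check to confirm this, just as a sketch:

# Sanity-check the feature-map shape with a dummy 227x227 batch
dummy = torch.randn(1, 3, 227, 227)
features_out = AlexNet().features(dummy)
print(features_out.shape)  # torch.Size([1, 256, 6, 6])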

 

3. Training code

def train(model, trainloader, loss_fn, optimizer, device):
    model.train()
    correct = 0
    running_size = 0
    running_loss = 0

    progress_bar = tqdm(trainloader)

    for i, data in enumerate(progress_bar):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        _, pred = output.max(dim=1)
        correct += pred.eq(labels).sum().item()

        running_loss += loss.item() * inputs.size(0) # loss is the batch mean, so scale by the batch size
        running_size += inputs.size(0)

    loss = running_loss / running_size
    acc = correct / running_size
    return loss, acc
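
To check that the training loop runs end to end before the full 30-epoch run, it can be smoke-tested on a tiny random batch. A minimal sketch; the dummy dataset below is purely illustrative and not part of the original post:

from torch.utils.data import TensorDataset, DataLoader

# 8 random 227x227 RGB images with random CIFAR-10-style labels (illustrative only)
dummy_set = TensorDataset(torch.randn(8, 3, 227, 227), torch.randint(0, 10, (8,)))
dummy_loader = DataLoader(dummy_set, batch_size=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet().to(device)
loss, acc = train(model, dummy_loader, nn.CrossEntropyLoss(), optim.Adam(model.parameters(), lr=0.001), device)
print(loss, acc)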

 

4. Test code

def test(model, testloader, device):
    model.eval()
    correct = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, dim=1)  # predicted class index per sample
            correct += (predicted == labels).sum().item()

    acc = 100. * correct / len(testloader.dataset)
    return acc
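
If a per-class breakdown is wanted, the same evaluation loop can be extended. A sketch only; the per-class counters here are my own addition, not part of the original code:

def test_per_class(model, testloader, device, num_classes=10):
    model.eval()
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            for c in range(num_classes):
                mask = (labels == c)
                total[c] += mask.sum().item()
                correct[c] += (predicted[mask] == c).sum().item()

    return (100. * correct / total).tolist()  # accuracy per class, in percent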

 

5. Running the model

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the dataset
    transform = transforms.Compose([
        transforms.Resize(size=(227, 227)),
        transforms.ToTensor(), # convert images to PyTorch tensors, scaled to [0.0, 1.0]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize each RGB channel to [-1, 1]
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False) # no shuffling needed for evaluation

    # Initialize the model and set hyperparameters
    model = AlexNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training and evaluation loop
    num_epochs = 30

    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)   
        test_acc = test(model, testloader, device)

        print(f'epoch {epoch+1:02d}, train loss: {train_loss:.5f}, train acc: {train_acc:.5f}, test accuracy: {test_acc:.5f}')


if __name__ == '__main__':
    main()
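
The trained weights are not saved anywhere in the script above; if they should be kept, a torch.save call could be added at the end of main(), after the epoch loop. A minimal sketch (the file name is an arbitrary choice):

# at the end of main(), after the epoch loop
torch.save(model.state_dict(), 'alexnet_cifar10.pth')

# later, to reload for inference
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet_cifar10.pth', map_location=device))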

 

6. Result

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|██████████| 170498071/170498071 [00:13<00:00, 12736287.09it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
100%|██████████| 391/391 [02:27<00:00,  2.65it/s]
epoch 01, train loss: 1.71098, train acc: 0.36702, test accuracy: 50.63000
100%|██████████| 391/391 [02:20<00:00,  2.79it/s]
epoch 02, train loss: 1.29971, train acc: 0.52890, test accuracy: 59.68000
100%|██████████| 391/391 [02:20<00:00,  2.79it/s]
epoch 03, train loss: 1.12004, train acc: 0.60120, test accuracy: 64.24000
100%|██████████| 391/391 [02:21<00:00,  2.76it/s]
epoch 04, train loss: 0.98776, train acc: 0.65162, test accuracy: 66.73000
100%|██████████| 391/391 [02:20<00:00,  2.77it/s]
epoch 05, train loss: 0.90384, train acc: 0.68420, test accuracy: 66.85000
100%|██████████| 391/391 [02:20<00:00,  2.78it/s]
epoch 06, train loss: 0.84316, train acc: 0.70506, test accuracy: 70.64000
100%|██████████| 391/391 [02:24<00:00,  2.71it/s]
epoch 07, train loss: 0.78164, train acc: 0.72584, test accuracy: 70.87000
100%|██████████| 391/391 [02:32<00:00,  2.57it/s]
epoch 08, train loss: 0.72490, train acc: 0.74338, test accuracy: 72.80000
100%|██████████| 391/391 [02:21<00:00,  2.76it/s]
epoch 09, train loss: 0.69673, train acc: 0.75462, test accuracy: 73.60000
100%|██████████| 391/391 [02:20<00:00,  2.78it/s]
epoch 10, train loss: 0.65764, train acc: 0.76806, test accuracy: 72.77000
100%|██████████| 391/391 [02:19<00:00,  2.81it/s]
epoch 11, train loss: 0.62567, train acc: 0.77932, test accuracy: 73.48000
100%|██████████| 391/391 [02:18<00:00,  2.82it/s]
epoch 12, train loss: 0.59616, train acc: 0.78912, test accuracy: 74.53000
100%|██████████| 391/391 [02:25<00:00,  2.69it/s]
epoch 13, train loss: 0.55846, train acc: 0.80272, test accuracy: 74.60000
100%|██████████| 391/391 [02:20<00:00,  2.78it/s]
epoch 14, train loss: 0.53184, train acc: 0.81340, test accuracy: 74.82000
100%|██████████| 391/391 [02:22<00:00,  2.75it/s]
epoch 15, train loss: 0.52647, train acc: 0.81510, test accuracy: 75.01000
100%|██████████| 391/391 [02:29<00:00,  2.62it/s]
epoch 16, train loss: 0.50323, train acc: 0.82410, test accuracy: 74.83000
100%|██████████| 391/391 [02:29<00:00,  2.61it/s]
epoch 17, train loss: 0.47027, train acc: 0.83294, test accuracy: 75.40000
100%|██████████| 391/391 [02:21<00:00,  2.76it/s]
epoch 18, train loss: 0.46619, train acc: 0.83752, test accuracy: 75.29000
100%|██████████| 391/391 [02:20<00:00,  2.79it/s]
epoch 19, train loss: 0.44510, train acc: 0.84490, test accuracy: 74.54000
100%|██████████| 391/391 [02:20<00:00,  2.79it/s]
epoch 20, train loss: 0.42874, train acc: 0.85056, test accuracy: 74.75000
100%|██████████| 391/391 [02:19<00:00,  2.81it/s]
epoch 21, train loss: 0.43053, train acc: 0.85150, test accuracy: 75.04000
100%|██████████| 391/391 [02:19<00:00,  2.79it/s]
epoch 22, train loss: 0.39896, train acc: 0.86068, test accuracy: 75.82000
100%|██████████| 391/391 [02:20<00:00,  2.78it/s]
epoch 23, train loss: 0.38840, train acc: 0.86602, test accuracy: 75.00000
100%|██████████| 391/391 [02:23<00:00,  2.73it/s]
epoch 24, train loss: 0.37481, train acc: 0.87074, test accuracy: 75.07000
100%|██████████| 391/391 [02:29<00:00,  2.61it/s]
epoch 25, train loss: 0.37298, train acc: 0.87290, test accuracy: 75.35000
100%|██████████| 391/391 [02:20<00:00,  2.79it/s]
epoch 26, train loss: 0.35478, train acc: 0.87682, test accuracy: 75.15000
100%|██████████| 391/391 [02:21<00:00,  2.77it/s]
epoch 27, train loss: 0.33714, train acc: 0.88418, test accuracy: 75.55000
100%|██████████| 391/391 [02:21<00:00,  2.77it/s]
epoch 28, train loss: 0.32900, train acc: 0.88902, test accuracy: 76.60000
100%|██████████| 391/391 [02:29<00:00,  2.62it/s]
epoch 29, train loss: 0.31907, train acc: 0.89206, test accuracy: 75.68000
100%|██████████| 391/391 [02:22<00:00,  2.74it/s]
epoch 30, train loss: 0.32230, train acc: 0.89208, test accuracy: 76.03000

Test accuracy plateaus around 75-76%, lol.
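
The gap between train accuracy (~89%) and test accuracy (~76%) points to overfitting. One common next step, not tried in this post, would be to add data augmentation to the training transform only, for example:

# Augmented training transform (suggestion only; the test transform stays as before)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(size=(227, 227)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])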
