AI, ML
Developing and training the AlexNet model using PyTorch on the CIFAR-10 dataset
개발공주
2023. 4. 10. 01:10
PyTorch model code based on the paper "ImageNet Classification with Deep Convolutional Neural Networks"
1. Library import
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
2. AlexNet Network
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # Convolutional feature extractor; expects 3 x 227 x 227 inputs
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),    # -> 96 x 55 x 55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> 96 x 27 x 27
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),   # -> 256 x 27 x 27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> 256 x 13 x 13
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),  # -> 384 x 13 x 13
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),  # -> 384 x 13 x 13
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),  # -> 256 x 13 x 13
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> 256 x 6 x 6
        )
        # Fully connected classifier; 10 outputs for the CIFAR-10 classes
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        x = self.features(x)
        # Flatten the feature maps to (batch_size, 256 * 6 * 6)
        x = x.view(x.size(0), 256 * 6 * 6)
        output = self.classifier(x)
        return output
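The `256 * 6 * 6` input size of the first `nn.Linear` layer follows from the 227x227 resolution the images are resized to before they reach the network. Here is a minimal sketch of that arithmetic, using the standard output-size formula floor((in + 2 * padding - kernel) / stride) + 1. The helper names `out_size` and `feature_map_size` are hypothetical and not part of the original code.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or max-pool layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of each conv / pool layer in self.features, in order
LAYERS = [
    (11, 4, 0),  # conv1
    (3, 2, 0),   # pool1
    (5, 1, 2),   # conv2
    (3, 2, 0),   # pool2
    (3, 1, 1),   # conv3
    (3, 1, 1),   # conv4
    (3, 1, 1),   # conv5
    (3, 2, 0),   # pool3
]

def feature_map_size(input_size):
    """Side length of the final feature map, or 0 if it collapses on the way."""
    size = input_size
    for kernel, stride, padding in LAYERS:
        size = out_size(size, kernel, stride, padding)
        if size < 1:
            return 0  # input too small for this stack of layers
    return size

side = feature_map_size(227)
print(side, 256 * side * side)  # side length and flattened feature count
print(feature_map_size(32))     # native CIFAR-10 resolution
```

At the native 32x32 CIFAR-10 resolution the feature map collapses before the last pooling layer, which is one reason to resize the images when keeping the paper's kernel sizes and strides.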
3. Training code
def train(model, trainloader, loss_fn, optimizer, device):
    model.train()
    correct = 0
    running_size = 0
    running_loss = 0
    progress_bar = tqdm(trainloader)
    for data in progress_bar:
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        _, pred = output.max(dim=1)
        correct += pred.eq(labels).sum().item()
        # loss is the batch mean, so weight it by the batch size before accumulating
        running_loss += loss.item() * inputs.size(0)
        running_size += inputs.size(0)
    # Per-sample average loss and accuracy (as a fraction, 0-1) over the epoch
    loss = running_loss / running_size
    acc = correct / running_size
    return loss, acc
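`nn.CrossEntropyLoss` returns the mean loss over the batch by default, which is why `train` multiplies `loss.item()` by the batch size before accumulating. With 50,000 training images and a batch size of 128, the last of the 391 batches holds only 80 images, so a plain average of per-batch means would overweight it slightly. Here is a small sketch of the difference. The loss values are made up and the helper name `epoch_loss` is an assumption for illustration.

```python
import math

def epoch_loss(batch_means, batch_sizes):
    """Sample-weighted mean loss, mirroring running_loss / running_size."""
    running_loss = sum(m * n for m, n in zip(batch_means, batch_sizes))
    running_size = sum(batch_sizes)
    return running_loss / running_size

# CIFAR-10 training set: 50,000 images, batch size 128
num_batches = math.ceil(50_000 / 128)
last_batch = 50_000 - (num_batches - 1) * 128

# Toy example: two full batches and one short batch (loss values are made up)
means = [1.0, 1.0, 4.0]
sizes = [128, 128, 80]
weighted = epoch_loss(means, sizes)  # (128 + 128 + 320) / 336
naive = sum(means) / len(means)      # treats the short batch like a full one

print(num_batches, last_batch)
print(round(weighted, 4), naive)
```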
4. Test code
def test(model, testloader, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            # Index of the highest logit is the predicted class
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
    # Returned as a percentage (0-100), unlike the fraction returned by train()
    acc = 100. * correct / len(testloader.dataset)
    return acc
5. Running the model
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the dataset
    transform = transforms.Compose([
        transforms.Resize(size=(227, 227)),
        transforms.ToTensor(),  # convert the image to a PyTorch tensor with values in 0.0-1.0
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per RGB channel, rescale to -1 to 1
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)  # TODO: fix
    # Initialize the model and set hyperparameters
    model = AlexNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # train, test
    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        test_acc = test(model, testloader, device)
        print(f'epoch {epoch+1:02d}, train loss: {train_loss:.5f}, train acc: {train_acc:.5f}, test accuracy: {test_acc:.5f}')

if __name__ == '__main__':
    main()
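The transform in `main` first scales 8-bit pixel values to 0.0-1.0 (`ToTensor`) and then applies `(x - mean) / std` per channel (`Normalize`), which with a mean and std of 0.5 maps the range to -1 to 1. A tiny sketch of that arithmetic in plain Python follows. The helper names `to_unit` and `normalize` are hypothetical and only mimic what the torchvision transforms compute for a single value.

```python
def to_unit(byte_value):
    """Scale an 8-bit pixel value (0-255) to the 0.0-1.0 range."""
    return byte_value / 255

def normalize(value, mean=0.5, std=0.5):
    """Per-channel normalization: (x - mean) / std."""
    return (value - mean) / std

for byte_value in (0, 255):
    unit = to_unit(byte_value)
    print(byte_value, "->", unit, "->", normalize(unit))
```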
6. Result
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|██████████| 170498071/170498071 [00:13<00:00, 12736287.09it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
100%|██████████| 391/391 [02:27<00:00, 2.65it/s]
epoch 01, train loss: 1.71098, train acc: 0.36702, test accuracy: 50.63000
100%|██████████| 391/391 [02:20<00:00, 2.79it/s]
epoch 02, train loss: 1.29971, train acc: 0.52890, test accuracy: 59.68000
100%|██████████| 391/391 [02:20<00:00, 2.79it/s]
epoch 03, train loss: 1.12004, train acc: 0.60120, test accuracy: 64.24000
100%|██████████| 391/391 [02:21<00:00, 2.76it/s]
epoch 04, train loss: 0.98776, train acc: 0.65162, test accuracy: 66.73000
100%|██████████| 391/391 [02:20<00:00, 2.77it/s]
epoch 05, train loss: 0.90384, train acc: 0.68420, test accuracy: 66.85000
100%|██████████| 391/391 [02:20<00:00, 2.78it/s]
epoch 06, train loss: 0.84316, train acc: 0.70506, test accuracy: 70.64000
100%|██████████| 391/391 [02:24<00:00, 2.71it/s]
epoch 07, train loss: 0.78164, train acc: 0.72584, test accuracy: 70.87000
100%|██████████| 391/391 [02:32<00:00, 2.57it/s]
epoch 08, train loss: 0.72490, train acc: 0.74338, test accuracy: 72.80000
100%|██████████| 391/391 [02:21<00:00, 2.76it/s]
epoch 09, train loss: 0.69673, train acc: 0.75462, test accuracy: 73.60000
100%|██████████| 391/391 [02:20<00:00, 2.78it/s]
epoch 10, train loss: 0.65764, train acc: 0.76806, test accuracy: 72.77000
100%|██████████| 391/391 [02:19<00:00, 2.81it/s]
epoch 11, train loss: 0.62567, train acc: 0.77932, test accuracy: 73.48000
100%|██████████| 391/391 [02:18<00:00, 2.82it/s]
epoch 12, train loss: 0.59616, train acc: 0.78912, test accuracy: 74.53000
100%|██████████| 391/391 [02:25<00:00, 2.69it/s]
epoch 13, train loss: 0.55846, train acc: 0.80272, test accuracy: 74.60000
100%|██████████| 391/391 [02:20<00:00, 2.78it/s]
epoch 14, train loss: 0.53184, train acc: 0.81340, test accuracy: 74.82000
100%|██████████| 391/391 [02:22<00:00, 2.75it/s]
epoch 15, train loss: 0.52647, train acc: 0.81510, test accuracy: 75.01000
100%|██████████| 391/391 [02:29<00:00, 2.62it/s]
epoch 16, train loss: 0.50323, train acc: 0.82410, test accuracy: 74.83000
100%|██████████| 391/391 [02:29<00:00, 2.61it/s]
epoch 17, train loss: 0.47027, train acc: 0.83294, test accuracy: 75.40000
100%|██████████| 391/391 [02:21<00:00, 2.76it/s]
epoch 18, train loss: 0.46619, train acc: 0.83752, test accuracy: 75.29000
100%|██████████| 391/391 [02:20<00:00, 2.79it/s]
epoch 19, train loss: 0.44510, train acc: 0.84490, test accuracy: 74.54000
100%|██████████| 391/391 [02:20<00:00, 2.79it/s]
epoch 20, train loss: 0.42874, train acc: 0.85056, test accuracy: 74.75000
100%|██████████| 391/391 [02:19<00:00, 2.81it/s]
epoch 21, train loss: 0.43053, train acc: 0.85150, test accuracy: 75.04000
100%|██████████| 391/391 [02:19<00:00, 2.79it/s]
epoch 22, train loss: 0.39896, train acc: 0.86068, test accuracy: 75.82000
100%|██████████| 391/391 [02:20<00:00, 2.78it/s]
epoch 23, train loss: 0.38840, train acc: 0.86602, test accuracy: 75.00000
100%|██████████| 391/391 [02:23<00:00, 2.73it/s]
epoch 24, train loss: 0.37481, train acc: 0.87074, test accuracy: 75.07000
100%|██████████| 391/391 [02:29<00:00, 2.61it/s]
epoch 25, train loss: 0.37298, train acc: 0.87290, test accuracy: 75.35000
100%|██████████| 391/391 [02:20<00:00, 2.79it/s]
epoch 26, train loss: 0.35478, train acc: 0.87682, test accuracy: 75.15000
100%|██████████| 391/391 [02:21<00:00, 2.77it/s]
epoch 27, train loss: 0.33714, train acc: 0.88418, test accuracy: 75.55000
100%|██████████| 391/391 [02:21<00:00, 2.77it/s]
epoch 28, train loss: 0.32900, train acc: 0.88902, test accuracy: 76.60000
100%|██████████| 391/391 [02:29<00:00, 2.62it/s]
epoch 29, train loss: 0.31907, train acc: 0.89206, test accuracy: 75.68000
100%|██████████| 391/391 [02:22<00:00, 2.74it/s]
epoch 30, train loss: 0.32230, train acc: 0.89208, test accuracy: 76.03000
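Just look at that accuracy, lol: test accuracy plateaus at around 75-76% while training accuracy keeps climbing to about 89%.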
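To read the log above a bit more systematically, here is a short sketch that parses lines in the printed format and reports the best test epoch and the train/test gap. Only three of the thirty epoch lines are copied in to keep it short. The regex and the helper name `parse_line` are assumptions for illustration. Note that `train acc` is logged as a fraction while `test accuracy` is a percentage, so the sketch rescales the former.

```python
import re

LOG = """\
epoch 01, train loss: 1.71098, train acc: 0.36702, test accuracy: 50.63000
epoch 28, train loss: 0.32900, train acc: 0.88902, test accuracy: 76.60000
epoch 30, train loss: 0.32230, train acc: 0.89208, test accuracy: 76.03000
"""

PATTERN = re.compile(
    r"epoch (\d+), train loss: ([\d.]+), train acc: ([\d.]+), test accuracy: ([\d.]+)"
)

def parse_line(line):
    """Return (epoch, train_loss, train_acc_percent, test_acc_percent) or None."""
    match = PATTERN.search(line)
    if match is None:
        return None  # e.g. tqdm progress-bar lines
    epoch, loss, train_acc, test_acc = match.groups()
    return int(epoch), float(loss), float(train_acc) * 100.0, float(test_acc)

rows = [r for r in map(parse_line, LOG.splitlines()) if r is not None]
best = max(rows, key=lambda r: r[3])
print(f"best test accuracy: {best[3]:.2f}% at epoch {best[0]}")
for epoch, _, train_acc, test_acc in rows:
    print(f"epoch {epoch:02d}: train/test gap = {train_acc - test_acc:.2f} points")
```

A widening gap between training and test accuracy in the later epochs is the usual sign of overfitting.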