Developing and training an AlexNet model with TensorFlow on the CIFAR-10 dataset
개발공주
2023. 4. 11. 02:20
1. Import libraries
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers, regularizers
2. Load and preprocess the CIFAR-10 dataset
# CIFAR-10: 50,000 training and 10,000 test images, 32x32 RGB, 10 classes
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1]
train_images, test_images = train_images / 255.0, test_images / 255.0

# Hold out the last 20% of the training split for validation
num_train = int(len(train_images) * 0.8)
train_images, validation_images = train_images[:num_train], train_images[num_train:]
train_labels, validation_labels = train_labels[:num_train], train_labels[num_train:]
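These NumPy arrays can be passed straight to model.fit, but wrapping them in a tf.data.Dataset adds per-epoch shuffling and input prefetching. A minimal sketch, not used in the training run below; BATCH_SIZE is an assumed name chosen to match the batch_size used later:

BATCH_SIZE = 128

train_ds = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
            .shuffle(10000)               # reshuffle the training examples each epoch
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))  # overlap input preparation with training
val_ds = (tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))
          .batch(BATCH_SIZE)
          .prefetch(tf.data.AUTOTUNE))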
3. Define an AlexNet-like model
def create_alexnet():
    model = models.Sequential([
        # Upsample 32x32 CIFAR-10 images to the 224x224 input size AlexNet expects
        layers.Resizing(224, 224, input_shape=(32, 32, 3)),

        layers.Conv2D(96, (11, 11), strides=(4, 4), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((3, 3), strides=(2, 2)),

        layers.Conv2D(256, (5, 5), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((3, 3), strides=(2, 2)),

        layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(256, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((3, 3), strides=(2, 2)),

        layers.Flatten(),
        # Two wide fully connected layers with L2 weight decay and dropout
        layers.Dense(4096, activation='relu', kernel_regularizer=regularizers.l2(0.0005)),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu', kernel_regularizer=regularizers.l2(0.0005)),
        layers.Dropout(0.5),
        # Softmax over the 10 CIFAR-10 classes (probabilities, not logits)
        layers.Dense(10, activation='softmax')
    ])
    return model
4. Create the model
model = create_alexnet()
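A quick call to model.summary() confirms the layer output shapes and shows that the first Dense layer after Flatten holds the bulk of the parameters:

model.summary()  # prints each layer's output shape and parameter count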
5. Compile the model
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),
              # The model's last layer applies softmax, so the loss receives
              # probabilities rather than logits
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
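Softmax can live either in the model or in the loss, but the two settings must agree: a softmax output layer pairs with from_logits=False, while a plain linear output pairs with from_logits=True, which is often slightly more numerically stable. A minimal sketch of the logits variant, assuming the final layer of create_alexnet() were changed to layers.Dense(10) with no activation:

# Logits variant: the loss applies softmax internally
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Probabilities then have to be recovered explicitly at inference time
probabilities = tf.nn.softmax(model.predict(test_images[:8]))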
6. Train the model
history = model.fit(train_images, train_labels,
                    epochs=50, batch_size=128,
                    validation_data=(validation_images, validation_labels))
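The results below come from this plain fit call. For longer runs it can help to add callbacks that stop early and checkpoint the best weights; a minimal sketch, not used here, with an arbitrarily chosen checkpoint filename:

callbacks = [
    # Stop once val_loss has not improved for 5 consecutive epochs
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
                                     restore_best_weights=True),
    # Save the best weights seen so far
    tf.keras.callbacks.ModelCheckpoint('alexnet_cifar10.h5',
                                       monitor='val_loss', save_best_only=True),
]

history = model.fit(train_images, train_labels,
                    epochs=50, batch_size=128,
                    validation_data=(validation_images, validation_labels),
                    callbacks=callbacks)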
7. Results
Epoch 1/50
313/313 [==============================] - 58s 125ms/step - loss: 6.3220 - accuracy: 0.3163 - val_loss: 4.0909 - val_accuracy: 0.4078
Epoch 2/50
313/313 [==============================] - 40s 127ms/step - loss: 3.3782 - accuracy: 0.4340 - val_loss: 3.1517 - val_accuracy: 0.3418
Epoch 3/50
313/313 [==============================] - 38s 122ms/step - loss: 2.5291 - accuracy: 0.5001 - val_loss: 2.5480 - val_accuracy: 0.4140
Epoch 4/50
313/313 [==============================] - 39s 125ms/step - loss: 2.1386 - accuracy: 0.5583 - val_loss: 3.2482 - val_accuracy: 0.3013
Epoch 5/50
313/313 [==============================] - 42s 133ms/step - loss: 1.9543 - accuracy: 0.6087 - val_loss: 2.6728 - val_accuracy: 0.4096
Epoch 6/50
313/313 [==============================] - 42s 135ms/step - loss: 1.7871 - accuracy: 0.6665 - val_loss: 1.8894 - val_accuracy: 0.6108
...
Epoch 47/50
313/313 [==============================] - 43s 137ms/step - loss: 0.2638 - accuracy: 0.9839 - val_loss: 1.2113 - val_accuracy: 0.8090
Epoch 48/50
313/313 [==============================] - 41s 130ms/step - loss: 0.2613 - accuracy: 0.9833 - val_loss: 1.3618 - val_accuracy: 0.7728
Epoch 49/50
313/313 [==============================] - 43s 136ms/step - loss: 0.2441 - accuracy: 0.9863 - val_loss: 1.1390 - val_accuracy: 0.8094
Epoch 50/50
313/313 [==============================] - 42s 136ms/step - loss: 0.2470 - accuracy: 0.9844 - val_loss: 1.2852 - val_accuracy: 0.7858
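Training accuracy climbs to roughly 98% while validation accuracy plateaus around 78-81%, so the model overfits despite dropout and L2 regularization; data augmentation or early stopping would be natural next steps. As a final check, the test split loaded in step 2 (and never touched during training) can be evaluated; a minimal sketch:

# Evaluate generalization on the held-out CIFAR-10 test set
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}')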