From 77fac4295d07e979b3dc5ddeaba93cbfecd09154 Mon Sep 17 00:00:00 2001
From: Abdulrahman <abdulrahman2.ali@live.uwe.ac.uk>
Date: Mon, 13 May 2024 15:23:02 +0300
Subject: [PATCH] chore: Organize and improve car prediction CNN model code

---
 .../FINAL_car_prediction_cnn_BRAND.ipynb      | 67 ++++++++++++++-----
 1 file changed, 52 insertions(+), 15 deletions(-)

diff --git a/mlmodel/Final/FINAL_car_prediction_cnn_BRAND.ipynb b/mlmodel/Final/FINAL_car_prediction_cnn_BRAND.ipynb
index 4efae18..b407119 100644
--- a/mlmodel/Final/FINAL_car_prediction_cnn_BRAND.ipynb
+++ b/mlmodel/Final/FINAL_car_prediction_cnn_BRAND.ipynb
@@ -79,23 +79,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": null,
    "id": "f9ff0b13-86db-4de5-b2a8-599a66ccc27b",
    "metadata": {
     "tags": []
    },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "  ...])"
-      ]
-     },
-     "execution_count": 33,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "import os\n",
     "\n",
@@ -123,10 +112,11 @@
     "\n",
     "# Create a DataFrame containing the paths to resized images and their labels\n",
     "resized_data = pd.DataFrame({\n",
-    "    'brand': data['brand']\n","
+    "    'src': [os.path.join(resized_dir, os.path.basename(src)) for src in data['src']],\n",
+    "    'brand': data['brand']\n",
     "})\n",
     "\n",
-    "X, y\n"
+    "# X, y\n"
    ]
   },
   {
@@ -324,6 +314,30 @@
     "predicted_class_index = np.argmax(predictions)\n",
     "\n",
     "class_to_brand = {\n",
+    "    0: 'Acura',\n",
+    "    1: 'Audi',\n",
+    "    2: 'Volkswagen',\n",
+    "    3: 'Toyota',\n",
+    "    4: 'Subaru',\n",
+    "    5: 'Porsche',\n",
+    "    6: 'Nissan',\n",
+    "    7: 'MINI',\n",
+    "    8: 'Mercedes-Benz',\n",
+    "    9: 'Mazda',\n",
+    "    10: 'Lincoln',\n",
+    "    11: 'Lexus',\n",
+    "    12: 'Kia',\n",
+    "    13: 'Jeep',\n",
+    "    14: 'Jaguar',\n",
+    "    15: 'Hyundai',\n",
+    "    16: 'Honda',\n",
+    "    17: 'GMC',\n",
+    "    18: 'Ford',\n",
+    "    19: 'Dodge',\n",
+    "    20: 'Chevrolet',\n",
+    "    21: 'Cadillac',\n",
+    "    22: 'BMW',\n",
+    "    23: 'Volvo',\n",
     "}\n",
     "\n",
     "# Get the predicted brand name\n",
@@ -584,6 +598,13 @@
     "from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping\n",
     "\n",
     "datagen = ImageDataGenerator(\n",
+    "    rotation_range=15,\n",
+    "    width_shift_range=0.1,\n",
+    "    height_shift_range=0.1,\n",
+    "    shear_range=0.1,\n",
+    "    zoom_range=0.1,\n",
+    "    horizontal_flip=True,\n",
+    "    vertical_flip=True,\n",
     "    fill_mode='nearest'\n",
     ")\n",
     "\n",
@@ -591,14 +612,30 @@
     "\n",
     "input_shape = X_train.shape[1:]\n",
     "model = models.Sequential([\n",
+    "    layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),\n",
+    "    layers.MaxPooling2D((2, 2)),\n",
+    "    layers.Conv2D(64, (3, 3), activation='relu'),\n",
+    "    layers.MaxPooling2D((2, 2)),\n",
+    "    layers.Conv2D(128, (3, 3), activation='relu'),\n",
+    "    layers.MaxPooling2D((2, 2)),\n",
+    "    layers.Flatten(),\n",
+    "    layers.Dropout(0.5),\n",
+    "    layers.Dense(512, activation='relu'),\n",
     "    layers.Dense(24, activation='softmax')\n",
     "])\n",
+    "model.compile(optimizer='adam',\n",
+    "              loss='categorical_crossentropy',\n",
     "              metrics=['accuracy'])\n",
     "\n",
     "checkpoint_callback = ModelCheckpoint('model2.keras', monitor='val_loss', save_best_only=True)\n",
     "early_stopping = EarlyStopping(monitor='val_loss', patience=10)\n",
     "\n",
     "history = model.fit(\n",
+    "    datagen.flow(X_train, y_train, batch_size=batch_size),\n",
+    "    epochs=100,\n",
+    "    validation_data=(X_test, y_test),\n",
+    "    verbose=1,\n",
+    "    callbacks=[checkpoint_callback, early_stopping],\n",
     ")"
    ]
   },
-- 
GitLab