diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6f428b8a44b0090916adcb90ba7be0aee1e9ad2a --- /dev/null +++ b/app.py @@ -0,0 +1,327 @@ +import os +from flask import Flask, render_template, request, jsonify, redirect, url_for, session +import numpy as np +import cv2 +import base64 +import io +from PIL import Image +import tensorflow as tf +import torch +import torch.nn as nn +from segmentation_models_pytorch import Unet +import mysql.connector +from argon2 import PasswordHasher, exceptions as argon2_exceptions + +from dotenv import load_dotenv +load_dotenv() + +app = Flask(__name__) + +app.secret_key = os.environ.get("SECRET_KEY") + +ph = PasswordHasher() + +# The database configuration pulled from .env +db_config = { + 'host': os.environ.get("DB_HOST"), + 'user': os.environ.get("DB_USER"), + 'password': os.environ.get("DB_PASSWORD"), + 'database': os.environ.get("DB_NAME") +} + +def get_db_connection(): + return mysql.connector.connect( + host=db_config['host'], + user=db_config['user'], + password=db_config['password'], + database=db_config['database'] + ) + +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} + +def allowed_file(filename): + return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +# This loads paths for AI models from environment +classification_model_path = os.environ.get("CLASSIFICATION_MODEL_PATH") +pre_classifier_path = os.environ.get("PRE_CLASSIFIER_PATH") +seg_model_path = os.environ.get("SEG_MODEL_PATH") + +# This loads the classification and pre-classifier models +classification_model = tf.keras.models.load_model(classification_model_path) +pre_classifier = tf.keras.models.load_model(pre_classifier_path) + +labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor'] + +class PretrainedUNet(nn.Module): + def __init__(self, in_channels=1, out_channels=1): + super(PretrainedUNet, self).__init__() + self.unet = Unet( + encoder_name="resnet101", + encoder_weights="imagenet", + in_channels=in_channels, + classes=out_channels + ) + + def forward(self, x): + return self.unet(x) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +seg_model = PretrainedUNet(in_channels=1, out_channels=1).to(DEVICE) + +# This loads the UNet model +seg_model.load_state_dict( + torch.load(seg_model_path, map_location=DEVICE) +) +seg_model.eval() + +# This defines available colormaps for visualisation +COLORMAPS = { + "jet": cv2.COLORMAP_JET, + "hot": cv2.COLORMAP_HOT, + "bone": cv2.COLORMAP_BONE, + "rainbow": cv2.COLORMAP_RAINBOW, + "ocean": cv2.COLORMAP_OCEAN, + "winter": cv2.COLORMAP_WINTER, + "parula": cv2.COLORMAP_PARULA, + "HSV": cv2.COLORMAP_HSV, +} + +# Function to perform segmentation on a grayscale image +def segment_image(grayscale_img: np.ndarray) -> np.ndarray: + input_size = 256 # This line defines input size expected by the model + original_h, original_w = grayscale_img.shape[:2] # The stores original dimensions + + # Resizes image to match model input size + resized_img = cv2.resize(grayscale_img, (input_size, input_size)) + # Create tensor, add batch and channel dimensions, move to device + tensor = torch.tensor(resized_img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE) + with torch.no_grad(): + logits = seg_model(tensor) # This line gets model output without gradient computation + probs = torch.sigmoid(logits).squeeze().cpu().numpy() # Apply sigmoid activation + + # Create binary mask with threshold 0.5 + mask_256 = (probs > 0.5).astype(np.uint8) + # Resize mask back to original image size + mask_original = cv2.resize(mask_256, (original_w, original_h), interpolation=cv2.INTER_NEAREST) + return mask_original + +# Convert BGR image to Base64 encoded PNG +def bgr_to_base64(img_bgr: np.ndarray) -> str: + pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) # Convert to RGB PIL image + buf = io.BytesIO() # Creates a bytes buffer + pil_img.save(buf, format='PNG') # Saves image into buffer + return base64.b64encode(buf.getvalue()).decode('utf-8') # Returns base64 encoded string + +# Function to enforce login for certain routes +def login_required(func): + def wrapper(*args, **kwargs): + if not session.get("logged_in"): # Checks login state + return redirect(url_for('home')) + return func(*args, **kwargs) + wrapper.__name__ = func.__name__ + return wrapper + +# Home route handling login logic +@app.route('/', methods=['GET', 'POST']) +def home(): + if session.get("logged_in"): + return redirect(url_for("index")) # Redirect if already logged in + + if request.method == 'GET': + return render_template('login.html') # Show login form + + # Process POST login attempt + email = request.form.get('email') + password = request.form.get('password') + print(f"DEBUG: Attempting login with email={email}, password={password}") + + conn = get_db_connection() + cursor = conn.cursor(dictionary=True) + try: + cursor.execute( + "SELECT email, password_hash FROM users WHERE email = %s LIMIT 1", + (email,) + ) + user_row = cursor.fetchone() + print("DEBUG: user_row:", user_row) + except mysql.connector.Error as err: + print("DEBUG: Database error:", err) + cursor.close() + conn.close() + return render_template('login.html', error_message="Database error. Please try again.") + finally: + cursor.close() + conn.close() + + if not user_row: + print(f"DEBUG: No user found with email={email}") + return render_template('login.html', error_message="Invalid details") + + db_hashed_pw = user_row['password_hash'] + print(f"DEBUG: DB hashed PW: {db_hashed_pw}") + + try: + # This code verifies provided password against stored hash + ph.verify(db_hashed_pw, password) + print("DEBUG: Password verified successfully!") + session['logged_in'] = True + session['user_email'] = email + return redirect(url_for('index')) + except argon2_exceptions.VerifyMismatchError: + return render_template('login.html', error_message="Invalid details") + except argon2_exceptions.VerificationError: + return render_template('login.html', error_message="Invalid details") + except Exception as e: + print("DEBUG: Unexpected exception in Argon2 verify:", e) + return render_template('login.html', error_message="Invalid details") + +# Main dashboard after login +@app.route('/index') +@login_required +def index(): + return render_template('index.html') + +# Page for image segmentation functionality +@app.route('/segmentation') +@login_required +def segmentation_page(): + return render_template('second.html') + +# Logout endpoint +@app.route('/logout') +def logout(): + session.clear() # Clear session data + return redirect(url_for('home')) + +# Analyse uploaded image for tumor type +@app.route('/analyze', methods=['POST']) +@login_required +def analyze(): + if 'file' not in request.files: + return jsonify({'error': 'No file uploaded'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'Empty filename'}), 400 + + if file and allowed_file(file.filename): + img_bytes = file.read() + img_array = np.frombuffer(img_bytes, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + if img is None: + return jsonify({'error': 'Unable to decode image'}), 400 + + # This is the preprocess code for MRI classification + img_for_pre = cv2.resize(img, (150, 150)) + x_pre_classifier = img_for_pre / 255.0 + x_pre_classifier = np.expand_dims(x_pre_classifier, axis=0) + is_mri_pred = pre_classifier.predict(x_pre_classifier) + is_mri = is_mri_pred[0][0] < 0.1 + + if not is_mri: + return jsonify({'error': 'Uploaded image is not an MRI scan'}), 400 + + # Classify tumor type + x_classify = cv2.resize(img, (150, 150)) + x_classify = np.expand_dims(x_classify, axis=0) + preds = classification_model.predict(x_classify) + index = np.argmax(preds[0]) + tumor_type = labels[index] + probability = float(preds[0][index]) + + return jsonify({ + 'tumor_type': tumor_type, + 'probability': probability + }), 200 + else: + return jsonify({'error': 'File type not allowed'}), 400 + +# Segment tumor region from uploaded MRI image +@app.route('/segment', methods=['POST']) +@login_required +def segment(): + if 'file' not in request.files: + return jsonify({'error': 'No file uploaded'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'Empty filename'}), 400 + + if file and allowed_file(file.filename): + img_bytes = file.read() + img_bgr = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR) + if img_bgr is None: + return jsonify({'error': 'Invalid image data'}), 400 + + # Verify it's an MRI scan + img_for_pre = cv2.resize(img_bgr, (150, 150)) + x_pre_classifier = img_for_pre / 255.0 + x_pre_classifier = np.expand_dims(x_pre_classifier, axis=0) + is_mri_pred = pre_classifier.predict(x_pre_classifier) + is_mri = is_mri_pred[0][0] < 0.1 + if not is_mri: + return jsonify({'error': 'Uploaded image is not an MRI scan'}), 400 + + # Segment tumor + img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + mask = segment_image(img_gray) + mask_3ch = cv2.merge([mask, mask, mask]).astype(np.uint8) + cutout_bgr = img_bgr * mask_3ch # Applies the mask to the original image + cutout_base64 = bgr_to_base64(cutout_bgr) + + return jsonify({'segmentation': cutout_base64}), 200 + else: + return jsonify({'error': 'File type not allowed'}), 400 + +# Apply a selected colormap to uploaded MRI scan +@app.route('/colormap', methods=['POST']) +@login_required +def apply_colormap(): + if 'file' not in request.files: + return jsonify({'error': 'No file uploaded'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'Empty filename'}), 400 + + colormap_name = request.form.get('colormap', 'jet') + if colormap_name not in COLORMAPS: + return jsonify({'error': f'Unknown colormap: {colormap_name}'}), 400 + + if file and allowed_file(file.filename): + img_bytes = file.read() + img_array = np.frombuffer(img_bytes, np.uint8) + mri_gray = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE) + if mri_gray is None: + return jsonify({'error': 'Invalid image data'}), 400 + + gray_3ch = cv2.cvtColor(mri_gray, cv2.COLOR_GRAY2BGR) # Converts to 3 channels + color_mapped = cv2.applyColorMap(gray_3ch, COLORMAPS[colormap_name]) # Applies the colormap + colormapped_b64 = bgr_to_base64(color_mapped) + + return jsonify({'colormapped': colormapped_b64}), 200 + else: + return jsonify({'error': 'File type not allowed'}), 400 + +# Database connectivity test for debugging +@app.route('/test_db') +def test_db(): + try: + connection = get_db_connection() + cursor = connection.cursor() + cursor.execute("SHOW TABLES;") + tables = cursor.fetchall() + return jsonify({"tables_in_db": [table[0] for table in tables]}) + except mysql.connector.Error as e: + return f"Database error: {e}", 500 + finally: + if cursor: + cursor.close() + if connection: + connection.close() + +# Run the app +if __name__ == '__main__': + debug_mode = os.environ.get("FLASK_DEBUG", "False").lower() == "true" + app.run(debug=debug_mode)