diff --git a/app/backend/database.py b/app/backend/database.py index 261851da7b8ae289d85b2e3cd707ab1cfdc54968..1571f4d478aa1fba94156d92cea71f87b660c464 100644 --- a/app/backend/database.py +++ b/app/backend/database.py @@ -6,6 +6,7 @@ engine = create_engine(settings.database_url, echo=settings.debug) def init_db(): + """Initialize database with tables and dummy data""" SQLModel.metadata.create_all(engine, checkfirst=True) with Session(engine) as session: insert_dummy_data(session) diff --git a/app/backend/dummy_data.py b/app/backend/dummy_data.py index 89a5687750a6b39d9d8086fe66ba39b7c4dcf396..fe3755682e8f6af1a981e04d458c00090e082304 100644 --- a/app/backend/dummy_data.py +++ b/app/backend/dummy_data.py @@ -14,7 +14,23 @@ from app.backend.models.models import ( from app.backend.utils.hashing import hash_password +def check_dummy_data_exists(session: Session) -> bool: + """Check if any dummy data already exists in the database""" + checks = [ + session.exec(select(User).where(User.email == "user@example.com")).first(), + session.exec(select(Shop).where(Shop.name == "Google HQ")).first(), + session.exec(select(Category).where(Category.name == "Category1")).first(), + session.exec(select(Product).where(Product.name == "Product1")).first(), + ] + return any(checks) + + def insert_dummy_data(session: Session): + """Insert dummy data only if it doesn't exist""" + if check_dummy_data_exists(session): + print("Dummy data already exists, skipping insertion") + return + if not session.exec(select(User)).first(): users = [ User( diff --git a/app/backend/routes/auth.py b/app/backend/routes/auth.py index a36aa5400cb3cdedbdfdfca5086af3a4dcbd657c..e288a295ff1e2e72cd7b255ea8019b74a544d468 100644 --- a/app/backend/routes/auth.py +++ b/app/backend/routes/auth.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from app.backend.models.models import User -from app.backend.schemas.user import UserCreate, UserLogin +from app.backend.schemas.user import UserCreate, UserLogin, UserRead, UserUpdate from app.backend.database import get_session from sqlmodel import Session, select from app.backend.utils.hashing import ( @@ -65,16 +65,10 @@ def login(user_data: UserLogin, session: Session = Depends(get_session)): } -@router.get("/profile") +@router.get("/profile", response_model=UserRead) def get_user_profile(current_user: User = Depends(get_current_user)): """Get the current user's profile information""" - return { - "username": current_user.username, - "name": current_user.username, # Just use username since name isn't in the model - "email": current_user.email, - "phone": current_user.phone_number, - "role": current_user.role, - } + return current_user @router.get("/role") @@ -85,21 +79,32 @@ def get_user_role(current_user: User = Depends(get_current_user)): @router.put("/update") def update_user_profile( - user_data: dict, + user_data: UserUpdate, current_user: User = Depends(get_current_user), session: Session = Depends(get_session), ): """Update the current user's profile information""" - # Update user fields - if "username" in user_data: - current_user.username = user_data["username"] - if "email" in user_data: - current_user.email = user_data["email"] - if "phone" in user_data: - current_user.phone_number = user_data["phone"] - - session.add(current_user) - session.commit() - session.refresh(current_user) - - return {"message": "Profile updated successfully"} + try: + # Update user fields if provided + if user_data.username is not None: + current_user.username = user_data.username + if user_data.email is not None: + # Check if email already exists for another user + existing_user = session.exec( + select(User) + .where(User.email == user_data.email) + .where(User.id != current_user.id) + ).first() + if existing_user: + raise HTTPException(status_code=400, detail="Email already registered") + current_user.email = user_data.email + if user_data.phone is not None: + current_user.phone_number = user_data.phone + + session.add(current_user) + session.commit() + session.refresh(current_user) + + return current_user + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) diff --git a/app/backend/schemas/user.py b/app/backend/schemas/user.py index be829391c3c826a00bd5c6d012afef2e5e62bf3c..39605fa118c9339b30128777e3444181a1247ee3 100644 --- a/app/backend/schemas/user.py +++ b/app/backend/schemas/user.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, EmailStr, ConfigDict +from pydantic import BaseModel, EmailStr, ConfigDict, field_validator from typing import Optional @@ -9,6 +9,22 @@ class UserCreate(BaseModel): phone_number: str password: str + @field_validator("password") + @classmethod + def password_must_not_be_empty(cls, v): + if not v or len(v.strip()) == 0: + raise ValueError("Password must not be empty") + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + return v + + @field_validator("phone_number") + @classmethod + def phone_number_must_be_valid(cls, v): + if not v.isdigit() or len(v) < 10: + raise ValueError("Phone number must be valid") + return v + # Schema for user login class UserLogin(BaseModel): @@ -26,8 +42,25 @@ class UserResponse(BaseModel): model_config = ConfigDict(from_attributes=True) +# Schema for reading user data +class UserRead(BaseModel): + username: str + phone_number: str + email: EmailStr + role: str + model_config = ConfigDict(from_attributes=True) + + # Schema for updating user profile class UserUpdate(BaseModel): username: Optional[str] = None email: Optional[EmailStr] = None - password: Optional[str] = None + phone: Optional[str] = None + + @field_validator("phone") + @classmethod + def phone_number_must_be_valid(cls, v): + if v is not None: + if not v.isdigit() or len(v) < 10: + raise ValueError("Phone number must be valid") + return v diff --git a/app/backend/utils/hashing.py b/app/backend/utils/hashing.py index 336fe4cf13c277e4c1436c5e7ca1607cdd4fa1bb..2090d16b2d45394e3b833f864649e3997238262d 100644 --- a/app/backend/utils/hashing.py +++ b/app/backend/utils/hashing.py @@ -46,6 +46,8 @@ def decode_token(token: str) -> int: status_code=401, detail="Invalid authentication credentials" ) return user_id + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token has expired") except PyJWTError as e: print(f"JWT error: {str(e)}") raise HTTPException( diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 1e8e2ef1ae2a4cb900534812182cce67d2cea921..d42d918bc8d7a47e0124797cbbd14dd5b9106b45 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -1,33 +1,56 @@ import pytest from fastapi import HTTPException from app.backend.schemas.user import UserCreate, UserLogin +from app.backend.utils.hashing import hash_password, verify_password +from app.backend.models.models import User +from sqlmodel import select def test_signup_success(client, db_session): """Test successful user registration""" + password = "testpassword123" + hashed_password = hash_password(password) # Hash password explicitly user_data = { "username": "testuser", "email": "test@example.com", "phone_number": "1234567890", - "password": "testpassword123", + "password": password, } response = client.post("/auth/signup", json=user_data) assert response.status_code == 200 assert response.json()["message"] == "User created successfully" + # Verify password was properly hashed + user = db_session.exec(select(User).where(User.email == user_data["email"])).first() + assert user is not None + assert user.password.startswith("$2b$") # Check bcrypt hash format + assert user.password != password # Ensure password was hashed + assert user.password != hashed_password # Each hash should be unique + assert verify_password(password, user.password) # Verify password matches hash + def test_signup_duplicate_email(client, db_session): """Test registration with duplicate email""" + password = "testpassword123" + hashed_password = hash_password(password) # Hash password explicitly user_data = { "username": "testuser", "email": "test@example.com", "phone_number": "1234567890", - "password": "testpassword123", + "password": password, } # Create first user - client.post("/auth/signup", json=user_data) + response = client.post("/auth/signup", json=user_data) + assert response.status_code == 200 + + # Verify first user's password is properly hashed + user = db_session.exec(select(User).where(User.email == user_data["email"])).first() + assert user.password.startswith("$2b$") # Check bcrypt hash format + assert user.password != password # Ensure password was hashed + assert user.password != hashed_password # Each hash should be unique + assert verify_password(password, user.password) # Try to create second user with same email response = client.post("/auth/signup", json=user_data) @@ -35,9 +58,55 @@ def test_signup_duplicate_email(client, db_session): assert "Email already registered" in response.json()["detail"] -def test_login_success(client, db_session): - """Test successful login""" - # First create a user +def test_signup_invalid_data(client, db_session): + """Test registration with invalid data""" + # Test empty password + user_data = { + "username": "testuser", + "email": "test@example.com", + "phone_number": "1234567890", + "password": "", + } + response = client.post("/auth/signup", json=user_data) + assert response.status_code == 422 + + # Test invalid email format + user_data["password"] = "testpassword123" + user_data["email"] = "invalid_email" + response = client.post("/auth/signup", json=user_data) + assert response.status_code == 422 + + +def test_invalid_token(client, db_session): + """Test accessing endpoints with invalid token""" + invalid_token = "invalid_token" + + # Test profile endpoint + response = client.get( + "/auth/profile", headers={"Authorization": f"Bearer {invalid_token}"} + ) + assert response.status_code == 401 + assert "Invalid authentication credentials" in response.json()["detail"] + + # Test role endpoint + response = client.get( + "/auth/role", headers={"Authorization": f"Bearer {invalid_token}"} + ) + assert response.status_code == 401 + assert "Invalid authentication credentials" in response.json()["detail"] + + +def test_token_expiration(client, db_session, monkeypatch): + """Test token expiration fast without real waiting""" + from datetime import timedelta + import time + + # Monkeypatch the token expiration to 1 second + monkeypatch.setattr( + "app.backend.utils.hashing.ACCESS_TOKEN_EXPIRE_MINUTES", 0.016 + ) # ~1 second + + # Create and login user user_data = { "username": "testuser", "email": "test@example.com", @@ -46,14 +115,55 @@ def test_login_success(client, db_session): } client.post("/auth/signup", json=user_data) + login_response = client.post( + "/auth/login", json={"email": "test@example.com", "password": "testpassword123"} + ) + token = login_response.json()["access_token"] + + # Sleep just a bit over 1 second + time.sleep(2) + + response = client.get("/auth/profile", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 401 + assert "Token has expired" in response.json()["detail"] + + +def test_login_success(client, db_session): + """Test successful login""" + password = "testpassword123" + user_data = { + "username": "testuser", + "email": "test@example.com", + "phone_number": "1234567890", + "password": password, + } + # Create user + response = client.post("/auth/signup", json=user_data) + assert response.status_code == 200 + + # Verify password was hashed + from app.backend.models.models import User + from sqlmodel import select + + user = db_session.exec(select(User).where(User.email == user_data["email"])).first() + assert verify_password(password, user.password) + # Then try to login - login_data = {"email": "test@example.com", "password": "testpassword123"} + login_data = {"email": "test@example.com", "password": password} response = client.post("/auth/login", json=login_data) assert response.status_code == 200 assert "access_token" in response.json() assert response.json()["role"] == "buyer" +def test_login_missing_user(client, db_session): + """Test login with non-existent user""" + login_data = {"email": "nonexistent@example.com", "password": "testpassword123"} + response = client.post("/auth/login", json=login_data) + assert response.status_code == 401 + assert "Invalid credentials" in response.json()["detail"] + + def test_login_invalid_credentials(client, db_session): """Test login with invalid credentials""" login_data = {"email": "wrong@example.com", "password": "wrongpassword"} @@ -83,7 +193,8 @@ def test_get_profile(client, db_session): assert response.status_code == 200 assert response.json()["username"] == "testuser" assert response.json()["email"] == "test@example.com" - assert response.json()["phone"] == "1234567890" + assert response.json()["phone_number"] == "1234567890" + assert response.json()["role"] == "buyer" def test_update_profile(client, db_session): @@ -119,7 +230,39 @@ def test_update_profile(client, db_session): ) assert profile_response.json()["username"] == "newusername" assert profile_response.json()["email"] == "newemail@example.com" - assert profile_response.json()["phone"] == "9876543210" + assert profile_response.json()["phone_number"] == "9876543210" + assert response.json()["role"] == "buyer" + + +def test_update_profile_invalid_data(client, db_session): + """Test updating profile with invalid data""" + # Create and login user + user_data = { + "username": "testuser", + "email": "test@example.com", + "phone_number": "1234567890", + "password": "testpassword123", + } + client.post("/auth/signup", json=user_data) + + login_response = client.post( + "/auth/login", json={"email": "test@example.com", "password": "testpassword123"} + ) + token = login_response.json()["access_token"] + + # Test invalid email format + update_data = {"email": "invalid_email"} + response = client.put( + "/auth/update", json=update_data, headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 422 + + # Test invalid phone number + update_data = {"phone": "invalid"} + response = client.put( + "/auth/update", json=update_data, headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 422 def test_get_user_role(client, db_session):