diff --git a/app/backend/dummy_data.py b/app/backend/dummy_data.py index f7a55b4eebfe187e36075be20138792aa6826c24..80db0c79c42d2fa78c422ccf58e166113d375b2d 100644 --- a/app/backend/dummy_data.py +++ b/app/backend/dummy_data.py @@ -7,6 +7,7 @@ from backend.models.models import ( ProductImage, Order, OrderItem, + Payment, ) from backend.utils.hashing import hash_password @@ -30,6 +31,26 @@ def insert_dummy_data(session: Session): session.add_all(users) session.commit() + if not session.query(Payment).first(): + payments = [ + Payment( + user_id=1, + payment_method="Visa", + card_number="**** **** **** 1234", + cvv="123", + expiry_date="12/25", + ), + Payment( + user_id=1, + payment_method="MasterCard", + card_number="**** **** **** 5678", + cvv="456", + expiry_date="11/26", + ), + ] + session.add_all(payments) + session.commit() + if not session.query(Shop).first(): shops = [ Shop( @@ -116,6 +137,7 @@ def insert_dummy_data(session: Session): Order( user_id=1, shop_id=1, + payment_id=1, # Link to a valid payment_id total_price=150.0, shipping_price=10.0, status="pending", @@ -126,6 +148,7 @@ def insert_dummy_data(session: Session): Order( user_id=1, shop_id=2, + payment_id=2, # Link to a valid payment_id total_price=200.0, shipping_price=15.0, status="shipped", @@ -133,36 +156,6 @@ def insert_dummy_data(session: Session): delivery_latitude=37.33182, delivery_longitude=-122.03118, ), - Order( - user_id=1, - shop_id=3, - total_price=250.0, - shipping_price=20.0, - status="delivered", - delivery_address="350 Fifth Avenue, New York, NY", - delivery_latitude=40.748817, - delivery_longitude=-73.985428, - ), - Order( - user_id=1, - shop_id=4, - total_price=300.0, - shipping_price=25.0, - status="pending", - delivery_address="221B Baker Street, London, UK", - delivery_latitude=51.523767, - delivery_longitude=-0.1585557, - ), - Order( - user_id=1, - shop_id=5, - total_price=350.0, - shipping_price=30.0, - status="canceled", - delivery_address="Eiffel Tower, Paris, France", - delivery_latitude=48.8588443, - delivery_longitude=2.2943506, - ), ] session.add_all(orders) session.commit() @@ -170,7 +163,7 @@ def insert_dummy_data(session: Session): if not session.query(OrderItem).first(): order_items = [ OrderItem(order_id=i, product_id=i, quantity=i, price=99.99 + i) - for i in range(1, 6) + for i in range(1, 2) ] session.add_all(order_items) session.commit() diff --git a/app/backend/main.py b/app/backend/main.py index 9606ccc22b982a4ee3f358b37390df63fc26d397..2cf270aad86fa6f2c97a3177e54bd35781238d1c 100644 --- a/app/backend/main.py +++ b/app/backend/main.py @@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from fastapi import FastAPI from fastapi.staticfiles import StaticFiles -from backend.routes import auth, shop, product, category, search, order +from backend.routes import auth, shop, product, category, search, order, payment from backend.database import init_db from core.config import settings @@ -13,7 +13,6 @@ app = FastAPI(title="Shopping App", version="1.0.0", debug=settings.debug) # ------------------- NEW: MOUNT STATIC FILES ------------------- # Suppose your static files are located in "app/static" -# Adjust the path if needed. static_dir_path = os.path.join(os.path.dirname(__file__), "..", "static") app.mount("/static", StaticFiles(directory=static_dir_path), name="static") @@ -23,6 +22,7 @@ init_db() # Include API routes app.include_router(search.router, prefix="/search", tags=["search"]) app.include_router(auth.router, prefix="/auth", tags=["auth"]) +app.include_router(payment.router, prefix="/payment", tags=["payment"]) app.include_router(shop.router, prefix="/shops", tags=["shops"]) app.include_router(product.router, prefix="/product", tags=["product"]) app.include_router(category.router, prefix="/category", tags=["category"]) diff --git a/app/backend/models/models.py b/app/backend/models/models.py index 77b9b590cf60683d884d81632073388953082281..8308831705efa3732c88a1ce48883179b08c05b8 100644 --- a/app/backend/models/models.py +++ b/app/backend/models/models.py @@ -8,6 +8,7 @@ class User(SQLModel, table=True): username: str email: str = Field(unique=True, index=True) password: str + payments: List["Payment"] = Relationship(back_populates="user") role: str = Field(default="customer") # Roles: customer, shop_owner, admin created_at: datetime = Field(default_factory=datetime.utcnow) @@ -65,6 +66,7 @@ class Order(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="user.id") shop_id: int = Field(foreign_key="shop.id") + payment_id: int = Field(foreign_key="payment.id") total_price: float shipping_price: float status: str = Field(default="pending") @@ -75,6 +77,7 @@ class Order(SQLModel, table=True): created_at: datetime = Field(default_factory=datetime.utcnow) user: User = Relationship(back_populates="orders") + payment: "Payment" = Relationship() order_items: List["OrderItem"] = Relationship(back_populates="order") @@ -86,3 +89,15 @@ class OrderItem(SQLModel, table=True): price: float order: Order = Relationship(back_populates="order_items") + + +class Payment(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="user.id") + payment_method: str # "Visa", "MasterCard", "PayPal" + card_number: str = Field(unique=True, nullable=False) + expiry_date: str = None # mm/yyyy + cvv: str = Field(nullable=False) # CVV code + created_at: datetime = Field(default_factory=datetime.utcnow) + + user: User = Relationship(back_populates="payments") diff --git a/app/backend/routes/auth.py b/app/backend/routes/auth.py index ed0d1e94d8cf56ca4a9649f0602b3f2863b12758..7dfc67be951acc8e333e9146a162a65f0e48ae20 100644 --- a/app/backend/routes/auth.py +++ b/app/backend/routes/auth.py @@ -4,8 +4,13 @@ from backend.models.models import User from backend.schemas.user import UserCreate, UserLogin from backend.database import get_session from sqlmodel import Session, select -from backend.utils.hashing import hash_password, verify_password -from app.core.security import decode_token, create_access_token +from backend.utils.hashing import ( + hash_password, + verify_password, + decode_token, + create_access_token, +) + router = APIRouter() diff --git a/app/backend/routes/order.py b/app/backend/routes/order.py index c2eefde4e63ae7fe5b6f18772dfa033a7b5ba463..ee9c0756970cea463e418edb13ff52bca98231cd 100644 --- a/app/backend/routes/order.py +++ b/app/backend/routes/order.py @@ -4,7 +4,7 @@ from geopy.geocoders import Nominatim from geopy.distance import geodesic from backend.database import get_session from backend.routes.auth import get_current_user -from backend.models.models import Order, OrderItem, User, Product, Shop +from backend.models.models import Order, OrderItem, User, Product, Shop, Payment from backend.schemas.order import OrderCreate, OrderRead, OrderUpdate router = APIRouter() @@ -21,6 +21,13 @@ def create_order( if not shop: raise HTTPException(status_code=404, detail="Shop not found") + # Validate the payment ID + payment = session.get(Payment, order_data.payment_id) + if not payment or payment.user_id != current_user.id: + raise HTTPException( + status_code=400, detail="Invalid or unauthorized payment method" + ) + # Geocode the delivery address geolocator = Nominatim(user_agent="order_locator") delivery_location = geolocator.geocode(order_data.delivery_address) @@ -48,6 +55,7 @@ def create_order( new_order = Order( user_id=current_user.id, shop_id=order_data.shop_id, + payment_id=order_data.payment_id, total_price=total_price, shipping_price=shipping_price, status="pending", diff --git a/app/backend/routes/payment.py b/app/backend/routes/payment.py new file mode 100644 index 0000000000000000000000000000000000000000..ba9e21e06716657c8d130f7a3a6e192fb64c12ac --- /dev/null +++ b/app/backend/routes/payment.py @@ -0,0 +1,88 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session +from backend.models.models import Payment, User +from backend.schemas.payment import PaymentCreate, PaymentRead +from backend.database import get_session +from backend.routes.auth import get_current_user + +router = APIRouter() + + +@router.post("/add", response_model=PaymentRead) +def add_payment( + payment_data: PaymentCreate, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +): + # Extract validated data + payment_method = payment_data.payment_method + card_number = payment_data.card_number + cvv = payment_data.cvv + expiry_date = payment_data.expiry_date + + # Create and save the payment + new_payment = Payment( + user_id=current_user.id, + payment_method=payment_method, + card_number=card_number, + cvv=cvv, + expiry_date=expiry_date, + ) + session.add(new_payment) + session.commit() + session.refresh(new_payment) + + return new_payment + + +@router.get("/{payment_id}", response_model=PaymentRead) +def read_payment( + payment_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +): + """Retrieve a specific payment by ID.""" + payment = session.get(Payment, payment_id) + if not payment or payment.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Payment not found") + return payment + + +@router.put("/{payment_id}", response_model=PaymentRead) +def update_payment( + payment_id: int, + payment_data: PaymentCreate, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +): + """Update a specific payment by ID.""" + payment = session.get(Payment, payment_id) + if not payment or payment.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Payment not found") + + # Update payment fields + payment.payment_method = payment_data.payment_method + payment.card_number = payment_data.card_number + payment.expiry_date = payment_data.expiry_date + payment.cvv = payment_data.cvv + + session.add(payment) + session.commit() + session.refresh(payment) + return payment + + +@router.delete("/{payment_id}", response_model=dict) +def delete_payment( + payment_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +): + """Delete a specific payment by ID.""" + payment = session.get(Payment, payment_id) + if not payment or payment.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Payment not found") + + session.delete(payment) + session.commit() + return {"detail": "Payment deleted successfully"} diff --git a/app/backend/schemas/order.py b/app/backend/schemas/order.py index ca5e2ce1dbe2a38530794397f98fdd90cdc38e3a..8cb9ede5854630152c8cf7d13f7ea086a33713ce 100644 --- a/app/backend/schemas/order.py +++ b/app/backend/schemas/order.py @@ -10,6 +10,7 @@ class OrderItemCreate(BaseModel): class OrderCreate(BaseModel): shop_id: int + payment_id: int items: List[OrderItemCreate] delivery_address: str diff --git a/app/backend/schemas/payment.py b/app/backend/schemas/payment.py new file mode 100644 index 0000000000000000000000000000000000000000..60090432a134c7910422970073f0066d22508053 --- /dev/null +++ b/app/backend/schemas/payment.py @@ -0,0 +1,65 @@ +from pydantic import BaseModel, Field, validator +from datetime import datetime +from backend.utils.security import ( + encrypt_card_number, + mask_card_number, + decrypt_card_number, +) + + +class PaymentCreate(BaseModel): + payment_method: str = Field(..., min_length=3, max_length=50) + card_number: str = Field(..., min_length=12, max_length=19, pattern=r"^\d+$") + cvv: str = Field(..., min_length=3, max_length=3, pattern=r"^\d+$") + expiry_date: str # Expecting MM/YY format + + @validator("card_number") + def encrypt_card(cls, value): + """Encrypt card number before storing.""" + return encrypt_card_number(value) + + @validator("cvv") + def encrypt_cvv(cls, value): + """Encrypt CVV before storing.""" + return encrypt_card_number(value) + + @validator("expiry_date") + def validate_expiry_date(cls, value): + """Validate expiry date format (MM/YY) and ensure it's in the future.""" + try: + expiry_obj = datetime.strptime(value, "%m/%y") + current_date = datetime.now() + + if expiry_obj.year < current_date.year or ( + expiry_obj.year == current_date.year + and expiry_obj.month < current_date.month + ): + raise ValueError("Card expiry date must be in the future") + + return value + except ValueError: + raise ValueError("Invalid expiry date format. Use MM/YY") + + +class PaymentRead(BaseModel): + id: int + user_id: int + payment_method: str + expiry_date: str # Keeping expiry date visible + created_at: datetime # Assuming you have a timestamp field + + @classmethod + def from_orm(cls, obj): + """Decrypt card number and mask it before returning.""" + decrypted_card = decrypt_card_number(obj.card_number) + return cls( + id=obj.id, + user_id=obj.user_id, + payment_method=obj.payment_method, + masked_card_number=mask_card_number(decrypted_card), + expiry_date=obj.expiry_date, + created_at=obj.created_at, + ) + + class Config: + from_attributes = True # Enable ORM mode diff --git a/app/backend/utils/hashing.py b/app/backend/utils/hashing.py index 3b8f8ce320d5fe04041181a68b28e623fd0d93f4..02df6fd461aaa069d1223166911129c319e93f5e 100644 --- a/app/backend/utils/hashing.py +++ b/app/backend/utils/hashing.py @@ -1,4 +1,13 @@ from passlib.context import CryptContext +import jwt +from datetime import datetime, timedelta +from fastapi import HTTPException +from jwt import PyJWTError +from core.config import settings + +SECRET_KEY = settings.secret_key +ALGORITHM = settings.algorithm +ACCESS_TOKEN_EXPIRE_MINUTES = 30 pwd_context = CryptContext(schemes=["bcrypt"], bcrypt__default_rounds=12) @@ -9,3 +18,27 @@ def hash_password(password: str) -> str: def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) + + +def create_access_token(data: dict): + to_encode = data.copy() + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> int: + try: + token = token.replace("Bearer ", "") # Remove "Bearer " prefix + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + user_id: int = payload.get("sub") + if user_id is None: + raise HTTPException( + status_code=401, detail="Invalid authentication credentials" + ) + return user_id + except PyJWTError: + raise HTTPException( + status_code=401, detail="Invalid authentication credentials" + ) diff --git a/app/backend/utils/security.py b/app/backend/utils/security.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc9aec845036c8c9c6028e55f96bee06b049a6a --- /dev/null +++ b/app/backend/utils/security.py @@ -0,0 +1,33 @@ +import base64 +from cryptography.fernet import Fernet +from core.config import settings + + +# Ensure the secret key is 32 bytes and Base64 encoded +def get_fernet_key(secret_key: str) -> bytes: + """Convert a plain string secret key to a valid Fernet key.""" + key_bytes = secret_key.encode() # Convert to bytes + padded_key = base64.urlsafe_b64encode(key_bytes.ljust(32)[:32]) # Ensure 32 bytes + return padded_key + + +# Convert settings.secret_key to a valid Fernet key +SECRET_KEY = get_fernet_key(settings.secret_key) + +# Initialize the Fernet encryption object +fernet = Fernet(SECRET_KEY) + + +def encrypt_card_number(card_number: str) -> str: + """Encrypts the card number.""" + return fernet.encrypt(card_number.encode()).decode() + + +def decrypt_card_number(encrypted_card: str) -> str: + """Decrypts the encrypted card number.""" + return fernet.decrypt(encrypted_card.encode()).decode() + + +def mask_card_number(card_number: str) -> str: + """Masks the card number, showing only the last 4 digits.""" + return f"**** **** **** {card_number[-4:]}" diff --git a/app/core/security.py b/app/core/security.py deleted file mode 100644 index 411ea600c4cb19086ad4355aa5bfa62addebde47..0000000000000000000000000000000000000000 --- a/app/core/security.py +++ /dev/null @@ -1,33 +0,0 @@ -import jwt -from datetime import datetime, timedelta -from fastapi import HTTPException -from jwt import PyJWTError -from core.config import settings - -SECRET_KEY = settings.secret_key -ALGORITHM = settings.algorithm -ACCESS_TOKEN_EXPIRE_MINUTES = 30 - - -def create_access_token(data: dict): - to_encode = data.copy() - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - - -def decode_token(token: str) -> int: - try: - token = token.replace("Bearer ", "") # Remove "Bearer " prefix - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - user_id: int = payload.get("sub") - if user_id is None: - raise HTTPException( - status_code=401, detail="Invalid authentication credentials" - ) - return user_id - except PyJWTError: - raise HTTPException( - status_code=401, detail="Invalid authentication credentials" - )