diff --git a/app/__init__.py b/app/__init__.py index 6a05c0b4998b1591590a89e75ce7093a5a105bf5..75e454890367ec299fec9a078e3b9cee21455e9d 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,29 +1,44 @@ -# Boilerplate create_app code taken from https://www.digitalocean.com/community/tutorials/how-to-structure-a-large-flask-application-with-flask-blueprints-and-flask-sqlalchemy - -from flask import Flask, url_for, redirect, session +from flask import Flask, url_for, redirect, session, g, abort from config import Config from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate from flask_login import LoginManager, current_user -from dotenv import load_dotenv +from flask_security import Security +from flask_principal import Principal, Permission, RoleNeed, Identity, identity_loaded, identity_changed from flask_wtf.csrf import CSRFProtect +from dotenv import load_dotenv +from functools import wraps import os +def permission_required(permission): + def decorator(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if not permission.can(): + abort(403) + return f(*args, **kwargs) + return decorated_function + return decorator + db = SQLAlchemy() migrate = Migrate() login_manager = LoginManager() csrf = CSRFProtect() -def create_app(config_class=Config): +from app.models import User, Role, RoleUsers + +super_admin_permission = Permission(RoleNeed('super-admin')) +admin_permission = Permission(RoleNeed('admin')) +user_permission = Permission(RoleNeed('user')) + +def create_app(config_class=Config): app = Flask(__name__) app.config.from_object(config_class) - #Only enabled when DEVELOPMENT_MODE in .env is set to true + # Only enabled when DEVELOPMENT_MODE in .env is set to true development_mode = os.getenv("DEVELOPMENT_MODE") - print(development_mode) - - #if (development_mode.lower() == 'true'): - app.config['DEBUG'] = True + if development_mode and development_mode.lower() == 'true': + app.config['DEBUG'] = True load_dotenv() db_host = os.getenv("DATABASE_HOST") @@ -31,22 +46,39 @@ def create_app(config_class=Config): db_password = os.getenv("DATABASE_PASSWORD") db_name = os.getenv("DATABASE_NAME") - app.config['SECRET_KEY'] = 'tnm]H+akmfnf_#PT>i|(Qo4LT@+n£9"~e3' - - app.config['SQLALCHEMY_DATABASE_URI'] = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}".format(db_user=db_user, db_password=db_password, db_host=db_host, db_name=db_name) - print(app.config['SQLALCHEMY_DATABASE_URI']) + app.config['SECRET_KEY'] = os.getenv('SECRET_KEY') or 'your_secret_key' + app.config['SQLALCHEMY_DATABASE_URI'] = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}" + db.init_app(app) - - #Run Flask migrations if any available migrate.init_app(app, db) - # Register blueprints and url prefixes + # Use RoleUsers.get_datastore() instead of RoleUsers.user_datastore + security = Security(app, RoleUsers.get_datastore()) + principal = Principal(app) + + # Register blueprints and URL prefixes register_blueprints(app) - #Protect internal endpoints from external use + # Protect internal endpoints from external use csrf.init_app(app) - - # Add any vars needed accessible through all templates + + # Identity loader + @identity_loaded.connect_via(app) + # Identity loader + @identity_loaded.connect_via(app) + def on_identity_loaded(sender, identity): + identity.user = current_user + if current_user.is_authenticated: + identity.provides.add(RoleNeed('user')) + for role in current_user.roles: + identity.provides.add(RoleNeed(role.name)) + # Should only be allocated to the root account, used for changing users to user -> admin + if role.name == 'super-admin': + identity.provides.add(RoleNeed('admin')) + identity.provides.add(RoleNeed('user')) + + + # Add global template variables @app.context_processor def set_global_html_variable_values(): try: @@ -55,22 +87,37 @@ def create_app(config_class=Config): else: user_in_session = False - template_config = {'user_in_session': user_in_session} + template_config = { + 'user_in_session': user_in_session, + 'admin_permission': g.admin_permission, + 'user_permission': g.user_permission, + 'super_admin_permission': g.super_admin_permission + } return template_config except Exception as e: - # print(f"Error in context processor: {e}") - return {'user_in_session': False} # Fallback, to create logging to record such failures (database corrupted etc.) - - + return { + 'user_in_session': False, + 'admin_permission': g.admin_permission, + 'user_permission': g.user_permission, + 'super_admin_permission': g.super_admin_permission + } + + @app.errorhandler(Exception) def handle_exception(e): app.logger.error(f"Unhandled exception: {e}") session['error_message'] = str(e) return redirect(url_for('errors.quandary')) - - if __name__ == "__main__": - app.run(use_reloader=True, debug=True) - + + @app.before_request + def before_request(): + g.admin_permission = admin_permission + g.user_permission = user_permission + g.super_admin_permission = super_admin_permission + if current_user.is_authenticated: + identity_changed.send(current_user._get_current_object(), identity=Identity(current_user.id)) + + login_manager.login_view = 'profile.login' login_manager.init_app(app) @@ -78,7 +125,6 @@ def create_app(config_class=Config): @login_manager.user_loader def load_user(user_id): - from .models import User return User.query.get(int(user_id)) def register_blueprints(app): @@ -93,7 +139,3 @@ def register_blueprints(app): for module_name, url_prefix in blueprints: module = __import__(f'app.{module_name}', fromlist=['bp']) app.register_blueprint(module.bp, url_prefix=url_prefix) - - - - diff --git a/app/bookings/routes.py b/app/bookings/routes.py index 39129c8ff19ccb8f47db940e7181a31075a18487..92779b487c4793013914a6293a9c20450040a932 100644 --- a/app/bookings/routes.py +++ b/app/bookings/routes.py @@ -1,8 +1,10 @@ -from flask import render_template, redirect, url_for +from flask import render_template, redirect, url_for, g from app.bookings import bp from app.models import Listings, ListingImages +from app import admin_permission, permission_required @bp.route('/home') +@permission_required(admin_permission) def index(): listing_ids = [] top_listings = Listings.get_top_listings(5) diff --git a/app/models/__init__.py b/app/models/__init__.py index 144c367fd0b277965a379ac840d39fb779bc4177..0e4cd79972ae43b67102256add861bf191b4760a 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,4 +1,6 @@ #Importing database models from .user import User from .listings import Listings -from .listing_images import ListingImages \ No newline at end of file +from .listing_images import ListingImages +from .role import Role +from .role_users import RoleUsers \ No newline at end of file diff --git a/app/models/role.py b/app/models/role.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5f65b07612d05ed6b3a5f5e611570d296a9579 --- /dev/null +++ b/app/models/role.py @@ -0,0 +1,10 @@ +from flask_security import RoleMixin +from app import db + +class Role(RoleMixin, db.Model): + __tablename__ = 'roles' + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(80), unique=True) + description = db.Column(db.String(255)) + diff --git a/app/models/role_users.py b/app/models/role_users.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3275773405771c621149c587555d6cb8cf8862 --- /dev/null +++ b/app/models/role_users.py @@ -0,0 +1,22 @@ +from flask_security import SQLAlchemyUserDatastore +from app import db + +class RoleUsers: + roles_users = db.Table('roles_users', + db.Column('user_id', db.Integer(), db.ForeignKey('users.id'), primary_key=True, index=True), + db.Column('role_id', db.Integer(), db.ForeignKey('roles.id')) + ) + + @staticmethod + def get_datastore(): + from app.models.role import Role + from app.models.user import User + return SQLAlchemyUserDatastore(db, User, Role) + + @staticmethod + def add_role_to_user(user, role_name): + from app.models.role import Role + role = Role.query.filter_by(name=role_name).first() + if role and role not in user.roles: + user.roles.append(role) + db.session.commit() diff --git a/app/models/user.py b/app/models/user.py index 40661cf9ab7fd340b938af0ede2d5e418013c9f2..5fd1487a59c7beb4b6a4d056dcd768eb3fe3c3c3 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,6 +2,8 @@ from flask import request, jsonify from flask_login import UserMixin from werkzeug.security import generate_password_hash, check_password_hash from app import db +import os +# Avoid importing Role and RoleUsers here to prevent circular import class User(UserMixin, db.Model): __tablename__ = 'users' @@ -10,19 +12,28 @@ class User(UserMixin, db.Model): username = db.Column(db.String(255), nullable=False, unique=True) email = db.Column(db.String(255), nullable=False, unique=True) password = db.Column(db.String(255), nullable=False) - role_id = db.Column(db.SmallInteger(), nullable=False) - api_token = db.Column(db.String(255), nullable=True, unique=True) - token_expiry = db.Column(db.DateTime(), nullable=True) + fs_uniquifier = db.Column(db.String(64), unique=True, nullable=False) # Add fs_uniquifier field + + # Import Role and RoleUsers only when defining the roles relationship + from app.models.role_users import RoleUsers + from app.models.role import Role + roles = db.relationship('Role', secondary=RoleUsers.roles_users, backref=db.backref('users', lazy='dynamic')) @classmethod - def create_user(cls, username, email, password, role_id = 1): # Role ID 1 is default for standard users - hashed_password = generate_password_hash(password, method='pbkdf2:sha256') - new_user = cls(username=username, email=email, password=hashed_password, role_id=role_id) + def create_user(cls, username, email, password, role_name='user'): + from app.models import Role + hashed_password = generate_password_hash(password, method='pbkdf2:sha256') + new_user = cls(username=username, email=email, password=hashed_password, fs_uniquifier=os.urandom(32).hex()) + + role = Role.query.filter_by(name=role_name).first() + if role: + new_user.roles.append(role) db.session.add(new_user) db.session.commit() - + return new_user + @classmethod def search_user_id(cls, user_id): return cls.query.get(user_id) @@ -44,4 +55,4 @@ class User(UserMixin, db.Model): user = cls.search_user_by_email(email) if user is None: - raise ValueError("Error") \ No newline at end of file + raise ValueError("Error") diff --git a/migrations/versions/22de5b143d05_create_user_roles.py b/migrations/versions/22de5b143d05_create_user_roles.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6a5869f16c6b4644fa9222896cee3d46486009 --- /dev/null +++ b/migrations/versions/22de5b143d05_create_user_roles.py @@ -0,0 +1,34 @@ +"""Create user roles + +Revision ID: 22de5b143d05 +Revises: 9a8cc1906445 +Create Date: 2025-01-06 13:40:11.307880 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column + +# revision identifiers, used by Alembic. +revision = '22de5b143d05' +down_revision = '9a8cc1906445' +branch_labels = None +depends_on = None + +roles_table = table('roles', + column('id', sa.Integer), + column('name', sa.String), + column('description', sa.String) +) + +def upgrade(): + roles = [ + {'name': 'super-admin', 'description': 'Super Admin, all admin perms and can create new admins'}, + {'name': 'admin', 'description': 'Can create/delete and modify bookings'}, + {'name': 'user', 'description': 'Standard user'} + ] + + op.bulk_insert(roles_table, roles) + +def downgrade(): + op.execute('DELETE FROM roles WHERE name IN ("super-admin", "admin", "user")') diff --git a/migrations/versions/9a8cc1906445_add_fs_uniquifier_field_to_user_model.py b/migrations/versions/9a8cc1906445_add_fs_uniquifier_field_to_user_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7bacfee528a6aff984243944fedccc2f36974923 --- /dev/null +++ b/migrations/versions/9a8cc1906445_add_fs_uniquifier_field_to_user_model.py @@ -0,0 +1,89 @@ +"""Add fs_uniquifier field to User model + +Revision ID: 9a8cc1906445 +Revises: 68d89ef13132 +Create Date: 2025-01-06 12:52:57.272220 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql +import os + +# revision identifiers, used by Alembic. +revision = '9a8cc1906445' +down_revision = '68d89ef13132' +branch_labels = None +depends_on = None + +def column_exists(table_name, column_name): + inspector = sa.inspect(op.get_bind()) + return column_name in [col['name'] for col in inspector.get_columns(table_name)] + +def index_exists(table_name, index_name): + inspector = sa.inspect(op.get_bind()) + indexes = inspector.get_indexes(table_name) + return any(index['name'] == index_name for index in indexes) + +def upgrade(): + # Conditionally create roles table + if not op.get_bind().dialect.has_table(op.get_bind(), "roles"): + op.create_table('roles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=80), nullable=True), + sa.Column('description', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + + # Conditionally create roles_users table + if not op.get_bind().dialect.has_table(op.get_bind(), "roles_users"): + op.create_table('roles_users', + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('role_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['role_id'], ['roles.id']), + sa.ForeignKeyConstraint(['user_id'], ['users.id']) + ) + + with op.batch_alter_table('listing_images', schema=None) as batch_op: + batch_op.alter_column('main_image', + existing_type=mysql.TINYINT(display_width=1), + type_=sa.SmallInteger(), + existing_nullable=False, + existing_server_default=sa.text("'0'")) + + # Assign unique values to fs_uniquifier for existing users before adding the unique constraint + conn = op.get_bind() + users = conn.execute(sa.text("SELECT id FROM users WHERE fs_uniquifier IS NULL OR fs_uniquifier = ''")).fetchall() + for user in users: + conn.execute(sa.text("UPDATE users SET fs_uniquifier = :fs_uniquifier WHERE id = :id"), {'fs_uniquifier': os.urandom(32).hex(), 'id': user.id}) + + with op.batch_alter_table('users', schema=None) as batch_op: + if index_exists('users', 'api_token'): + batch_op.drop_index('api_token') + batch_op.create_unique_constraint(None, ['fs_uniquifier']) + if column_exists('users', 'token_expiry'): + batch_op.drop_column('token_expiry') + if column_exists('users', 'api_token'): + batch_op.drop_column('api_token') + if column_exists('users', 'role_id'): + batch_op.drop_column('role_id') + +def downgrade(): + with op.batch_alter_table('users', schema=None) as batch_op: + batch_op.add_column(sa.Column('role_id', mysql.SMALLINT(), server_default=sa.text("'1'"), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('api_token', mysql.VARCHAR(length=255), nullable=True)) + batch_op.add_column(sa.Column('token_expiry', mysql.DATETIME(), nullable=True)) + batch_op.drop_constraint(None, type_='unique') + batch_op.create_index('api_token', ['api_token'], unique=True) + batch_op.drop_column('fs_uniquifier') + + with op.batch_alter_table('listing_images', schema=None) as batch_op: + batch_op.alter_column('main_image', + existing_type=sa.SmallInteger(), + type_=mysql.TINYINT(display_width=1), + existing_nullable=False, + existing_server_default=sa.text("'0'")) + + op.drop_table('roles_users') + op.drop_table('roles') diff --git a/migrations/versions/ad8ca3c3dfaa_add_composite_primary_key_and_index_to_.py b/migrations/versions/ad8ca3c3dfaa_add_composite_primary_key_and_index_to_.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffcbf94f8f82a0daec75b7331366ac24c3f77c6 --- /dev/null +++ b/migrations/versions/ad8ca3c3dfaa_add_composite_primary_key_and_index_to_.py @@ -0,0 +1,35 @@ +"""Add composite primary key and index to roles_users + +Revision ID: ad8ca3c3dfaa +Revises: 22de5b143d05 +Create Date: 2025-01-06 13:56:13.747100 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ad8ca3c3dfaa' +down_revision = '22de5b143d05' +branch_labels = None +depends_on = None + + +def upgrade(): + op.drop_table('roles_users') + + op.create_table( + 'roles_users', + sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id'), primary_key=True, index=True), + sa.Column('role_id', sa.Integer, sa.ForeignKey('roles.id'), primary_key=True) + ) + +def downgrade(): + op.drop_table('roles_users') + + op.create_table( + 'roles_users', + sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id')), + sa.Column('role_id', sa.Integer, sa.ForeignKey('roles.id')) + ) diff --git a/requirements.txt b/requirements.txt index eebb30e827ff7a6a96ee58cf463372a60942ccb4..19f100b7c8d8625dfff47bdad0b237b2a169217d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,6 @@ jinja2 cryptography flask-login debugpy -flask-wtf \ No newline at end of file +flask-wtf +flask-security +flask-principal \ No newline at end of file