diff --git a/rust_crud_api/Cargo.toml b/rust_crud_api/Cargo.toml index 1714ef2decb1555a9dedb243ae149e754d982f49..ef75e3ba299654186d7413aace2d9a1784967f49 100644 --- a/rust_crud_api/Cargo.toml +++ b/rust_crud_api/Cargo.toml @@ -12,4 +12,6 @@ pyo3 = { version = "0.18", features = ["extension-module"] } postgres = "0.19" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -jsonwebtoken = "8.2.0" +jsonwebtoken = "8.2.0" +argon2 = "0.4" +rand = "0.8" diff --git a/rust_crud_api/src/lib.rs b/rust_crud_api/src/lib.rs index 0d446d7835ba4324103216efaef85504c5c3baec..fbe7ef2d856cfcc2e72a0dd7a4115066da1eecb1 100644 --- a/rust_crud_api/src/lib.rs +++ b/rust_crud_api/src/lib.rs @@ -67,6 +67,37 @@ fn verify_jwt(token: String, secret: String) -> PyResult<PyObject> { Ok(dict.to_object(py)) } +// Pasword Hashing with Argon2 + +#[pyfunction] +fn hash_password(password: &str) -> PyResult<String> { + use argon2::{Argon2, PasswordHasher}; + use argon2::password_hash::{SaltString, rand_core::OsRng}; + + // Generate a random salt using OsRng. + let salt = SaltString::generate(&mut OsRng); + // Create a default Argon2 instance. + let argon2 = Argon2::default(); + // Hash the password using the salt. + let password_hash = argon2.hash_password(password.as_bytes(), &salt) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .to_string(); + Ok(password_hash) +} + +#[pyfunction] +fn verify_password(hash: &str, password: &str) -> PyResult<bool> { + use argon2::{Argon2, PasswordVerifier}; + use argon2::password_hash::PasswordHash; + + let argon2 = Argon2::default(); + let parsed_hash = PasswordHash::new(hash) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + argon2.verify_password(password.as_bytes(), &parsed_hash) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + .map(|_| true) +} + /// A helper function to convert postgres::Error into a Python RuntimeError. fn pg_err(e: postgres::Error) -> PyErr { PyRuntimeError::new_err(e.to_string()) @@ -81,7 +112,8 @@ fn init_db(db_url: &str) -> PyResult<()> { CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, name VARCHAR NOT NULL, - email VARCHAR NOT NULL + email VARCHAR NOT NULL, + password_hash VARCHAR NOT NULL ); CREATE TABLE IF NOT EXISTS groups ( id SERIAL PRIMARY KEY, @@ -99,11 +131,12 @@ fn init_db(db_url: &str) -> PyResult<()> { /// Create a new user by inserting into the database. #[pyfunction] -fn create_user(db_url: &str, name: &str, email: &str) -> PyResult<()> { +fn create_user(db_url: &str, name: &str, email: &str, password: &str) -> PyResult<()> { + let password_hash = hash_password(password)?; let mut client = Client::connect(db_url, NoTls).map_err(pg_err)?; client.execute( - "INSERT INTO users (name, email) VALUES ($1, $2)", - &[&name, &email] + "INSERT INTO users (name, email, password_hash) VALUES ($1, $2, $3)", + &[&name, &email, &password_hash] ).map_err(pg_err)?; Ok(()) } @@ -173,6 +206,24 @@ fn delete_user(db_url: &str, user_id: i32) -> PyResult<bool> { Ok(deleted > 0) } +///Verify credentials +///Retrieve a stored password hash for the email and compare. +#[pyfunction] +fn verify_user(db_url: &str, email: &str, password: & str) -> PyResult<bool> { + let mut client = Client::connect(db_url, NoTls).map_err(pg_err)?; + let row_opt = client.query_opt( + "SELECT password_hash FROM users WHERE email = $1", + &[&email] + ).map_err(pg_err)?; + if let Some(row) = row_opt { + let stored_hash: String = row.get(0); + verify_password(&stored_hash, password) + } else { + Ok(false) + } +} + + /// Create a group #[pyfunction] fn create_group(db_url: &str, name: &str) -> PyResult<()> { @@ -270,6 +321,8 @@ fn rust_crud_api(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_group_members, m)?)?; m.add_function(wrap_pyfunction!(generate_jwt, m)?)?; m.add_function(wrap_pyfunction!(verify_jwt, m)?)?; + m.add_function(wrap_pyfunction!(hash_password, m)?)?; + m.add_function(wrap_pyfunction!(verify_password, m)?)?; Ok(()) }