diff --git a/Cargo.lock b/Cargo.lock index 82114de..79a4fbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -131,6 +131,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "bcrypt" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d1c9c15093eb224f0baa400f38fcd713fc1391a6f1c389d886beef146d60a3" +dependencies = [ + "base64 0.21.2", + "blowfish", + "getrandom", + "subtle", + "zeroize", +] + [[package]] name = "binascii" version = "0.1.4" @@ -158,6 +171,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blowfish" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e412e2cd0f2b2d93e02543ceae7917b3c70331573df19ee046bcbc35e45e87d7" +dependencies = [ + "byteorder", + "cipher", +] + [[package]] name = "bumpalo" version = "3.14.0" @@ -191,6 +214,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.4.11" @@ -229,6 +262,7 @@ name = "common" version = "0.2.0" dependencies = [ "anyhow", + "bcrypt", "binascii", "diesel", "diesel-async", @@ -710,6 +744,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "is-terminal" version = "0.4.9" @@ -2087,3 +2130,9 @@ checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2" dependencies = [ "memchr", ] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/common/Cargo.toml b/common/Cargo.toml index 12f372e..379d679 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,6 +22,7 @@ diesel = {version = "2.1", optional = true, default-features = false, features = diesel-async = {version = "0.3", optional = true, features = ["postgres", "bb8"]} anyhow = "1.0" prometheus-client = "0.21.2" +bcrypt = "0.15.0" [features] default = ["web", "db"] diff --git a/common/migrations/2023-02-01-210714_init/up.sql b/common/migrations/2023-02-01-210714_init/up.sql index c545f5a..e89f6de 100644 --- a/common/migrations/2023-02-01-210714_init/up.sql +++ b/common/migrations/2023-02-01-210714_init/up.sql @@ -37,3 +37,9 @@ CREATE TABLE IF NOT EXISTS funcs ( CREATE UNIQUE INDEX IF NOT EXISTS funcs_db ON funcs(chksum, db_id); CREATE INDEX IF NOT EXISTS funcs_ranking ON funcs(chksum,rank); CREATE INDEX IF NOT EXISTS func_chksum ON funcs(chksum); + +CREATE TABLE IF NOT EXISTS auth_users ( + id SERIAL PRIMARY KEY, + username VARCHAR(32) NOT NULL UNIQUE, + password_hash VARCHAR(128) NOT NULL +); diff --git a/common/src/db/mod.rs b/common/src/db/mod.rs index 8f83bbc..8899be1 100644 --- a/common/src/db/mod.rs +++ b/common/src/db/mod.rs @@ -3,11 +3,12 @@ use postgres_native_tls::MakeTlsConnector; use serde::Serialize; use tokio_postgres::{tls::MakeTlsConnect, Socket, NoTls}; use std::{collections::HashMap}; +use bcrypt::{hash, verify, DEFAULT_COST}; use crate::async_drop::{AsyncDropper, AsyncDropGuard}; mod schema_auto; pub mod schema; -use diesel::{upsert::excluded, ExpressionMethods, QueryDsl, NullableExpressionMethods, sql_types::{Array, Binary, VarChar, Integer}, query_builder::{QueryFragment, Query}}; +use diesel::{upsert::excluded, ExpressionMethods, QueryDsl, NullableExpressionMethods, sql_types::{Array, Binary, VarChar, Integer}, query_builder::{QueryFragment, Query}, result::Error::NotFound}; use diesel_async::RunQueryDsl; pub type DynConfig = dyn crate::config::HasConfig + Send + Sync; @@ -337,6 +338,76 @@ impl Database { Ok(results) } + pub async fn register_user(&self, username: &str, password: &str) -> Result<(), anyhow::Error> { + let conn = &mut self.diesel.get().await?; + + // Hash the password + let hashed_password = hash(password, DEFAULT_COST)?; + + // Insert new user + diesel::insert_into(schema::auth_users::table) + .values(( + schema::auth_users::username.eq(username), + schema::auth_users::password_hash.eq(hashed_password), + )) + .execute(conn) + .await?; + + Ok(()) + } + + pub async fn change_user_password(&self, username: &str, new_password: &str) -> Result<(), anyhow::Error> { + let conn = &mut self.diesel.get().await?; + + // Hash the new password + let hashed_password = hash(new_password, DEFAULT_COST)?; + + // Update the user's password + diesel::update(schema::auth_users::table.filter(schema::auth_users::username.eq(username))) + .set(schema::auth_users::password_hash.eq(hashed_password)) + .execute(conn) + .await?; + + Ok(()) + } + + pub async fn remove_user(&self, username: &str) -> Result<(), anyhow::Error> { + let conn = &mut self.diesel.get().await?; + + // Execute the delete query + diesel::delete(schema::auth_users::table.filter(schema::auth_users::username.eq(username))) + .execute(conn) + .await?; + + Ok(()) + } + + pub async fn auth_user(&self, login: &str, password: &str) -> Result { + let conn = &mut self.diesel.get().await?; + + match schema::auth_users::table + .select((schema::auth_users::username, schema::auth_users::password_hash)) + .filter(schema::auth_users::username.eq(login)) + .first::<(String, String)>(conn) + .await { + Ok((_, password_hash)) => { + // If user is found, verify the password + match verify(password, &password_hash) { + Ok(valid) => Ok(valid), + Err(e) => Err(anyhow::Error::new(e)), + } + }, + Err(diesel::result::Error::NotFound) => { + // If user is not found, return false + Ok(false) + }, + Err(e) => { + // For all other errors, return the error + Err(anyhow::Error::new(e)) + } + } + } + pub async fn get_files_with_func(&self, func: &[u8]) -> Result>, anyhow::Error> { let conn = &mut self.diesel.get().await?; diff --git a/common/src/db/schema_auto.rs b/common/src/db/schema_auto.rs index 3700def..ccd69b2 100644 --- a/common/src/db/schema_auto.rs +++ b/common/src/db/schema_auto.rs @@ -1,5 +1,13 @@ // @generated automatically by Diesel CLI. +diesel::table! { + auth_users (id) { + id -> Int4, + username -> Varchar, + password_hash -> Varchar, + } +} + diesel::table! { dbs (id) { id -> Int4, @@ -46,6 +54,7 @@ diesel::joinable!(dbs -> users (user_id)); diesel::joinable!(funcs -> dbs (db_id)); diesel::allow_tables_to_appear_in_same_query!( + auth_users, dbs, files, funcs, diff --git a/lumen/src/main.rs b/lumen/src/main.rs index adc649e..03b68d0 100644 --- a/lumen/src/main.rs +++ b/lumen/src/main.rs @@ -200,15 +200,25 @@ async fn handle_client(state: &SharedState, m }).inc(); if let Some(ref creds) = creds { - if creds.username != "guest" { + + let auth_state = state.db.auth_user(creds.username, creds.password).await; + + if !auth_state.is_ok() || !auth_state.unwrap() { // Only allow "guest" to connect for now. rpc::RpcMessage::Fail(rpc::RpcFail { code: 1, - message: &format!("{server_name}: invalid username or password. Try logging in with `guest` instead."), + message: &format!("{server_name}: invalid username or password."), }).async_write(&mut stream).await?; return Ok(()); } } + else { + rpc::RpcMessage::Fail(rpc::RpcFail { + code: 1, + message: &format!("{server_name}: username and password should be specified."), + }).async_write(&mut stream).await?; + return Ok(()); + } let resp = match hello.protocol_version { 0..=4 => rpc::RpcMessage::Ok(()), @@ -329,6 +339,33 @@ fn main() { .default_value("config.toml") .help("Configuration file path") ) + .subcommand( + clap::Command::new("add_user") + .about("Adds a new user") + .arg(Arg::new("username") + .help("The username for the new user") + .required(true)) + .arg(Arg::new("password") + .help("The password for the new user") + .required(true)) + ) + .subcommand( + clap::Command::new("change_user_pass") + .about("Changes a user's password") + .arg(Arg::new("username") + .help("The username of the user") + .required(true)) + .arg(Arg::new("new_password") + .help("The new password for the user") + .required(true)) + ) + .subcommand( + clap::Command::new("remove_user") + .about("Removes user") + .arg(Arg::new("username") + .help("The username of the user") + .required(true)) + ) .get_matches(); let config = { @@ -347,7 +384,7 @@ fn main() { exit(1); }, }; - + let db = rt.block_on(async { match Database::open(&config.database).await { Ok(v) => v, @@ -367,8 +404,46 @@ fn main() { metrics: common::metrics::Metrics::default(), }); - let tls_acceptor; + let subcommand_future = async { + match matches.subcommand() { + Some(("add_user", sub_m)) => { + let username = sub_m.get_one::("username").unwrap(); + let password = sub_m.get_one::("password").unwrap(); + + match state.db.register_user(username, password).await { + Ok(_) => println!("User added successfully"), + Err(e) => eprintln!("Error adding user: {}", e), + } + exit(0); + }, + Some(("change_user_pass", sub_m)) => { + let username = sub_m.get_one::("username").unwrap(); + let new_password = sub_m.get_one::("new_password").unwrap(); + match state.db.change_user_password(username, new_password).await { + Ok(_) => println!("User password changed successfully"), + Err(e) => eprintln!("Error changing user password: {}", e), + } + exit(0); + }, + Some(("remove_user", sub_m)) => { + let username = sub_m.get_one::("username").unwrap(); + + match state.db.remove_user(username).await { + Ok(_) => println!("User removed successfully"), + Err(e) => eprintln!("Error removing user: {}", e), + } + exit(0); + }, + _ => { + + } + } + }; + + rt.block_on(subcommand_future); + + let tls_acceptor; if state.config.lumina.use_tls.unwrap_or_default() { let cert_path = &state.config.lumina.tls.as_ref().expect("tls section is missing").server_cert; let mut crt = match std::fs::read(cert_path) { @@ -427,7 +502,7 @@ fn main() { }; info!("listening on {:?} secure={}", server.local_addr().unwrap(), tls_acceptor.is_some()); - + serve(server, tls_acceptor, state, exit_signal_rx).await; };