Adding a service to handle caching

This commit is contained in:
Z. Charles Dziura 2024-10-05 08:09:46 -04:00
parent 8d4a987f0d
commit d41a92d07a
28 changed files with 466 additions and 115 deletions

View file

@ -1,4 +1,5 @@
ASSETS_DIR=/home/zcdziura/Documents/Projects/debt-pirate/api/assets ASSETS_DIR=/home/zcdziura/Documents/Projects/debt-pirate/api/assets
CACHE_URL=redis://debt_pirate:H553jOui2734@192.168.122.251:6379
DATABASE_URL=postgres://debt_pirate:HRURqlUmtjIy@192.168.122.251/debt_pirate DATABASE_URL=postgres://debt_pirate:HRURqlUmtjIy@192.168.122.251/debt_pirate
HOSTNAME=localhost HOSTNAME=localhost
MAINTENANCE_USER_ACCOUNT=debt_pirate:HRURqlUmtjIy MAINTENANCE_USER_ACCOUNT=debt_pirate:HRURqlUmtjIy

View file

@ -11,10 +11,12 @@ axum = { version = "0.7", features = [
"ws", "ws",
] } ] }
base64 = "0.22" base64 = "0.22"
bb8-redis = "0.17"
blake3 = { version = "1.5", features = ["serde"] }
dotenvy = "0.15" dotenvy = "0.15"
futures = "0.3" futures = "0.3"
http = "1.0" http = "1.0"
humantime = "2.1.0" humantime = "2.1"
humantime-serde = "1.1" humantime-serde = "1.1"
hyper = { version = "1.1", features = ["full"] } hyper = { version = "1.1", features = ["full"] }
lettre = { version = "0.11", default-features = false, features = [ lettre = { version = "0.11", default-features = false, features = [
@ -27,6 +29,7 @@ lettre = { version = "0.11", default-features = false, features = [
] } ] }
num_cpus = "1.16" num_cpus = "1.16"
pasetors = "0.7" pasetors = "0.7"
redis = { version = "0.27", features = ["aio"] }
serde = { version = "1.0", features = ["derive", "rc", "std"] } serde = { version = "1.0", features = ["derive", "rc", "std"] }
serde_json = "1.0" serde_json = "1.0"
serde_with = "3.9" serde_with = "3.9"
@ -37,11 +40,11 @@ sqlx = { version = "0.8", features = [
"runtime-tokio", "runtime-tokio",
] } ] }
syslog-tracing = "0.3.1" syslog-tracing = "0.3.1"
time = { version = "0.3.36", features = ["formatting", "macros"] } time = { version = "0.3", features = ["formatting", "macros"] }
tokio = { version = "1.35", features = ["full"] } tokio = { version = "1.35", features = ["full"] }
tower = "0.5" tower = "0.5"
tower-http = { version = "0.6", features = ["full"] } tower-http = { version = "0.6", features = ["full"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "time"] } tracing-subscriber = { version = "0.3", features = ["env-filter", "time"] }
ulid = "1.1" url = { version = "2.5.2", features = ["expose_internals"] }
uuid = { version = "1.10", features = ["serde", "v7"] } uuid = { version = "1.10", features = ["serde", "v7"] }

View file

@ -7,12 +7,12 @@ use crate::models::AppError;
pub type DbPool = Pool<Postgres>; pub type DbPool = Pool<Postgres>;
pub async fn create_connection_pool(connection_uri: &str) -> DbPool { pub async fn create_db_connection_pool(connection_uri: String) -> DbPool {
let num_cpus = num_cpus::get() as u32; let num_cpus = num_cpus::get() as u32;
PgPoolOptions::new() PgPoolOptions::new()
.max_connections(num_cpus) .max_connections(num_cpus)
.connect(connection_uri) .connect(connection_uri.as_str())
.await .await
.unwrap() .unwrap()
} }

View file

@ -59,13 +59,35 @@ pub async fn insert_new_user(
}) })
} }
#[derive(Debug, FromRow)]
pub struct UserIdAndHashedPassword {
pub id: i32,
pub password: String,
}
pub async fn get_username_and_password_by_username(
pool: &DbPool,
username: String,
) -> Result<UserIdAndHashedPassword, AppError> {
sqlx::query_as::<_, UserIdAndHashedPassword>(
"SELECT id, password FROM public.user WHERE username = $1;",
)
.bind(username)
.fetch_one(pool)
.await
.map_err(|err| {
error!(%err, "Unable to find user");
AppError::from(err)
})
}
pub async fn verify_user(pool: &DbPool, user_id: i32) -> Result<(), AppError> { pub async fn verify_user(pool: &DbPool, user_id: i32) -> Result<(), AppError> {
sqlx::query("UPDATE public.user SET status_id = 1, updated_at = now() WHERE id = $1;") sqlx::query("UPDATE public.user SET status_id = 1, updated_at = now() WHERE id = $1;")
.bind(user_id) .bind(user_id)
.execute(pool) .execute(pool)
.await .await
.map_err(|err| { .map_err(|err| {
eprintln!("Error verifying user with id '{user_id}'."); error!(%err, user_id, "Error verifying user");
AppError::from(err) AppError::from(err)
}) })
.map(|_| ()) .map(|_| ())

View file

@ -7,11 +7,13 @@ mod models;
mod requests; mod requests;
mod services; mod services;
use db::{create_connection_pool, run_migrations}; use db::{create_db_connection_pool, run_migrations};
use requests::start_app; use requests::start_app;
use services::{initialize_logger, start_emailer_service, UserConfirmationMessage}; use services::{
create_cache_connection_pool, initialize_logger, start_emailer_service, UserConfirmationMessage,
};
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tracing::info; use tracing::{error, info};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -28,11 +30,18 @@ async fn main() {
initialize_logger(&env); initialize_logger(&env);
info!("Initializing database connection pool..."); info!("Initializing database connection pool...");
let pool = create_connection_pool(env.db_connection_uri()).await; let db_pool = create_db_connection_pool(env.db_connection_uri().to_string()).await;
info!("Database connection pool created successfully."); info!("Database connection pool created successfully.");
info!("Initializing cache service connection pool...");
let cache_pool = create_cache_connection_pool(env.cache_url().to_string())
.await
.inspect_err(|err| error!(?err))
.unwrap();
info!("Cache service connection pool created successfully.");
info!("Running database schema migrations..."); info!("Running database schema migrations...");
if let Err(err) = run_migrations(&pool).await { if let Err(err) = run_migrations(&db_pool).await {
eprintln!("{err:?}"); eprintln!("{err:?}");
process::exit(2); process::exit(2);
} }
@ -42,7 +51,7 @@ async fn main() {
start_emailer_service(Handle::current(), env.assets_dir(), rx); start_emailer_service(Handle::current(), env.assets_dir(), rx);
info!("Email service started successfully."); info!("Email service started successfully.");
if let Err(err) = start_app(pool, env).await { if let Err(err) = start_app(db_pool, cache_pool, env).await {
eprintln!("{err:?}"); eprintln!("{err:?}");
process::exit(3); process::exit(3);
} }

View file

@ -50,7 +50,7 @@ impl ApiResponse<()> {
} }
} }
pub fn error(error: &'static str) -> Self { pub fn _error(error: &'static str) -> Self {
Self { Self {
meta: None, meta: None,
data: None, data: None,

View file

@ -4,6 +4,8 @@ use std::{
}; };
use pasetors::{keys::SymmetricKey, version4::V4}; use pasetors::{keys::SymmetricKey, version4::V4};
use tracing::trace;
use url::Url;
use crate::services::UserConfirmationMessage; use crate::services::UserConfirmationMessage;
@ -12,7 +14,8 @@ use super::AppError;
#[derive(Clone)] #[derive(Clone)]
pub struct Environment { pub struct Environment {
assets_dir: PathBuf, assets_dir: PathBuf,
database_url: String, cache_url: Url,
database_url: Url,
email_sender: Sender<UserConfirmationMessage>, email_sender: Sender<UserConfirmationMessage>,
hostname: String, hostname: String,
port: u32, port: u32,
@ -29,6 +32,7 @@ impl Environment {
.filter_map(|item| item.ok()) .filter_map(|item| item.ok())
.for_each(|(key, value)| match key.as_str() { .for_each(|(key, value)| match key.as_str() {
"ASSETS_DIR" => builder.with_assets_dir(value), "ASSETS_DIR" => builder.with_assets_dir(value),
"CACHE_URL" => builder.with_cache_url(value),
"DATABASE_URL" => builder.with_database_url(value), "DATABASE_URL" => builder.with_database_url(value),
"HOSTNAME" => builder.with_hostname(value), "HOSTNAME" => builder.with_hostname(value),
"PORT" => builder.with_port(value), "PORT" => builder.with_port(value),
@ -46,6 +50,22 @@ impl Environment {
} }
} }
pub fn assets_dir(&self) -> &Path {
self.assets_dir.as_path()
}
pub fn cache_url(&self) -> &Url {
&self.cache_url
}
pub fn db_connection_uri(&self) -> &Url {
&self.database_url
}
pub fn email_sender(&self) -> &Sender<UserConfirmationMessage> {
&self.email_sender
}
pub fn hostname(&self) -> &str { pub fn hostname(&self) -> &str {
self.hostname.as_str() self.hostname.as_str()
} }
@ -54,22 +74,6 @@ impl Environment {
self.port self.port
} }
pub fn token_key(&self) -> &SymmetricKey<V4> {
&self.token_key
}
pub fn db_connection_uri(&self) -> &str {
self.database_url.as_str()
}
pub fn assets_dir(&self) -> &Path {
self.assets_dir.as_path()
}
pub fn email_sender(&self) -> &Sender<UserConfirmationMessage> {
&self.email_sender
}
pub fn rust_log(&self) -> &str { pub fn rust_log(&self) -> &str {
self.rust_log.as_str() self.rust_log.as_str()
} }
@ -77,12 +81,17 @@ impl Environment {
pub fn send_verification_email(&self) -> bool { pub fn send_verification_email(&self) -> bool {
self.send_verification_email self.send_verification_email
} }
pub fn token_key(&self) -> &SymmetricKey<V4> {
&self.token_key
}
} }
impl From<EnvironmentObjectBuilder> for Environment { impl From<EnvironmentObjectBuilder> for Environment {
fn from(builder: EnvironmentObjectBuilder) -> Self { fn from(builder: EnvironmentObjectBuilder) -> Self {
let EnvironmentObjectBuilder { let EnvironmentObjectBuilder {
assets_dir, assets_dir,
cache_url,
database_url, database_url,
email_sender, email_sender,
hostname, hostname,
@ -94,6 +103,7 @@ impl From<EnvironmentObjectBuilder> for Environment {
Self { Self {
assets_dir: assets_dir.unwrap(), assets_dir: assets_dir.unwrap(),
cache_url: cache_url.unwrap(),
database_url: database_url.unwrap(), database_url: database_url.unwrap(),
email_sender: email_sender.unwrap(), email_sender: email_sender.unwrap(),
hostname: hostname.unwrap(), hostname: hostname.unwrap(),
@ -108,7 +118,8 @@ impl From<EnvironmentObjectBuilder> for Environment {
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct EnvironmentObjectBuilder { pub struct EnvironmentObjectBuilder {
pub assets_dir: Option<PathBuf>, pub assets_dir: Option<PathBuf>,
pub database_url: Option<String>, pub cache_url: Option<Url>,
pub database_url: Option<Url>,
pub email_sender: Option<Sender<UserConfirmationMessage>>, pub email_sender: Option<Sender<UserConfirmationMessage>>,
pub hostname: Option<String>, pub hostname: Option<String>,
pub port: Option<u32>, pub port: Option<u32>,
@ -128,13 +139,20 @@ impl EnvironmentObjectBuilder {
pub fn uninitialized_variables(&self) -> Option<Vec<&'static str>> { pub fn uninitialized_variables(&self) -> Option<Vec<&'static str>> {
let mut missing_vars = [ let mut missing_vars = [
("HOSTNAME", self.hostname.as_deref()), ("HOSTNAME", self.hostname.as_deref()),
("DATABASE_URL", self.database_url.as_deref()),
("RUST_LOG", self.rust_log.as_deref()), ("RUST_LOG", self.rust_log.as_deref()),
] ]
.into_iter() .into_iter()
.filter_map(|(key, value)| value.map(|_| key).xor(Some(key))) .filter_map(|(key, value)| value.map(|_| key).xor(Some(key)))
.collect::<Vec<&'static str>>(); .collect::<Vec<&'static str>>();
if self.cache_url.is_none() {
missing_vars.push("CACHE_URL");
}
if self.database_url.is_none() {
missing_vars.push("DATABASE_URL");
}
if self.token_key.is_none() { if self.token_key.is_none() {
missing_vars.push("TOKEN_KEY"); missing_vars.push("TOKEN_KEY");
} }
@ -170,8 +188,28 @@ impl EnvironmentObjectBuilder {
}; };
} }
pub fn with_cache_url(&mut self, url: String) {
trace!(?url);
let cache_url = url
.parse::<Url>()
.expect("The 'CACHE_URL' variable is not in valid URI format");
trace!(?cache_url);
if cache_url.scheme().to_lowercase() != "redis" {
panic!("The 'CACHE_URL' must be a valid Redis connection string; it must use the 'redis://' scheme");
}
self.cache_url = Some(cache_url);
}
pub fn with_database_url(&mut self, url: String) { pub fn with_database_url(&mut self, url: String) {
self.database_url = Some(url); let database_url = url
.parse::<Url>()
.expect("The 'DATABASE_URL' variable is not in valid URI format");
self.database_url = Some(database_url);
} }
pub fn with_assets_dir(&mut self, assets_dir_path: String) { pub fn with_assets_dir(&mut self, assets_dir_path: String) {

View file

@ -1,8 +1,11 @@
use std::{borrow::Cow, error::Error, fmt::Display, io}; use std::{borrow::Cow, error::Error, fmt::Display, io};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use bb8_redis::bb8::RunError;
use http::StatusCode; use http::StatusCode;
use redis::RedisError;
use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError}; use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError};
use tracing::trace;
use super::ApiResponse; use super::ApiResponse;
@ -11,6 +14,8 @@ pub struct AppError {
kind: ErrorKind, kind: ErrorKind,
} }
impl Error for AppError {}
impl AppError { impl AppError {
fn new(kind: ErrorKind) -> Self { fn new(kind: ErrorKind) -> Self {
Self { kind } Self { kind }
@ -20,17 +25,20 @@ impl AppError {
Self::new(ErrorKind::AppStartupError(error)) Self::new(ErrorKind::AppStartupError(error))
} }
pub fn connection_info(service_name: &'static str) -> Self {
Self::new(ErrorKind::ConnectionInfo(service_name))
}
pub fn duplicate_record(message: &str) -> Self { pub fn duplicate_record(message: &str) -> Self {
Self::new(ErrorKind::DuplicateRecord(message.to_owned())) Self::new(ErrorKind::DuplicateRecord(message.to_owned()))
} }
pub fn invalid_token() -> Self { pub fn invalid_password() -> Self {
Self::new(ErrorKind::InvalidToken) Self::new(ErrorKind::InvalidPassword)
} }
#[allow(dead_code)] pub fn invalid_token() -> Self {
pub fn invalid_token_audience(audience: &str) -> Self { Self::new(ErrorKind::InvalidToken)
Self::new(ErrorKind::InvalidTokenAudience(audience.to_owned()))
} }
pub fn missing_environment_variables(missing_vars: Vec<&'static str>) -> Self { pub fn missing_environment_variables(missing_vars: Vec<&'static str>) -> Self {
@ -74,11 +82,30 @@ impl From<SqlxError> for AppError {
} }
} }
impl From<RedisError> for AppError {
fn from(other: RedisError) -> Self {
trace!(err = ?other, "Cache error");
Self::new(ErrorKind::Cache(other.to_string()))
}
}
impl From<RunError<RedisError>> for AppError {
fn from(other: RunError<RedisError>) -> Self {
trace!(err = ?other, "Cache pool error");
Self::new(ErrorKind::Cache(other.to_string()))
}
}
impl Display for AppError { impl Display for AppError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind { match &self.kind {
ErrorKind::AppStartupError(err) => write!(f, "{err}"), ErrorKind::AppStartupError(err) => write!(f, "{err}"),
ErrorKind::Database => write!(f, "Unknown database error occurred."), ErrorKind::Cache(err) => write!(f, "{err}"),
ErrorKind::ConnectionInfo(service) => write!(
f,
"Unable to connect to '{service}' service; invalid connection string"
),
ErrorKind::Database => write!(f, "Unknown database error occurred"),
ErrorKind::DbMigration(err) => write!( ErrorKind::DbMigration(err) => write!(
f, f,
"Error occurred while initializing connection to database: {err}" "Error occurred while initializing connection to database: {err}"
@ -86,11 +113,8 @@ impl Display for AppError {
ErrorKind::DuplicateRecord(message) => { ErrorKind::DuplicateRecord(message) => {
write!(f, "Duplicate database record: {message}") write!(f, "Duplicate database record: {message}")
} }
ErrorKind::InvalidToken => write!(f, "The provided token is invalid."), ErrorKind::InvalidPassword => write!(f, "Invalid password"),
ErrorKind::InvalidTokenAudience(audience) => write!( ErrorKind::InvalidToken => write!(f, "The provided token is invalid"),
f,
"The provided token is not valid for this endpoint: '{audience}'."
),
ErrorKind::MissingEnvironmentVariables(missing_vars) => write!( ErrorKind::MissingEnvironmentVariables(missing_vars) => write!(
f, f,
"Missing required environment variables: {}", "Missing required environment variables: {}",
@ -99,22 +123,22 @@ impl Display for AppError {
ErrorKind::Sqlx(err) => write!(f, "{err}"), ErrorKind::Sqlx(err) => write!(f, "{err}"),
ErrorKind::TokenKey => write!( ErrorKind::TokenKey => write!(
f, f,
"Invalid PASETO symmetric key; must be in valid PASERK format." "Invalid PASETO symmetric key; must be in valid PASERK format"
), ),
} }
} }
} }
impl Error for AppError {}
#[derive(Debug)] #[derive(Debug)]
enum ErrorKind { enum ErrorKind {
AppStartupError(io::Error), AppStartupError(io::Error),
Cache(String),
ConnectionInfo(&'static str),
Database, Database,
DbMigration(MigrateError), DbMigration(MigrateError),
DuplicateRecord(String), DuplicateRecord(String),
InvalidPassword,
InvalidToken, InvalidToken,
InvalidTokenAudience(String),
MissingEnvironmentVariables(Vec<&'static str>), MissingEnvironmentVariables(Vec<&'static str>),
Sqlx(SqlxError), Sqlx(SqlxError),
TokenKey, TokenKey,
@ -127,7 +151,7 @@ impl IntoResponse for AppError {
StatusCode::CONFLICT, StatusCode::CONFLICT,
ApiResponse::new_with_error(self).into_json_response(), ApiResponse::new_with_error(self).into_json_response(),
), ),
&ErrorKind::InvalidToken | &ErrorKind::InvalidTokenAudience(_) => ( &ErrorKind::InvalidPassword | &ErrorKind::InvalidToken => (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
ApiResponse::new_with_error(self).into_json_response(), ApiResponse::new_with_error(self).into_json_response(),
), ),

View file

@ -1,7 +1,9 @@
mod api_response; mod api_response;
mod environment; mod environment;
mod error; mod error;
mod session;
pub use api_response::*; pub use api_response::*;
pub use environment::*; pub use environment::*;
pub use error::*; pub use error::*;
pub use session::*;

32
api/src/models/session.rs Normal file
View file

@ -0,0 +1,32 @@
use std::time::SystemTime;
use humantime::format_rfc3339;
use redis::ToRedisArgs;
#[derive(Debug)]
pub struct Session {
pub user_id: i32,
pub username: String,
pub created_at: SystemTime,
pub expires_at: SystemTime,
}
impl Session {
pub fn to_cacheable_object<'a>(
self,
) -> Vec<(&'static str, impl ToRedisArgs + Send + Sync + 'a)> {
let Self {
user_id,
username,
created_at,
expires_at,
} = self;
vec![
("userId", user_id.to_string()),
("username", username),
("createdAt", format_rfc3339(created_at).to_string()),
("expiresAt", format_rfc3339(expires_at).to_string()),
]
}
}

View file

@ -1,11 +1,39 @@
use axum::{ use axum::{
debug_handler, debug_handler,
extract::State,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json,
};
use tracing::debug;
use crate::{
db::{get_username_and_password_by_username, DbPool, UserIdAndHashedPassword},
models::AppError,
requests::AppState,
services::verify_password,
}; };
use crate::models::AppError; use super::models::AuthLoginRequest;
#[debug_handler] #[debug_handler]
pub async fn auth_login_post_handler() -> Result<Response, AppError> { pub async fn auth_login_post_handler(
State(state): State<AppState>,
Json(body): Json<AuthLoginRequest>,
) -> Result<Response, AppError> {
let pool = state.db_pool();
auth_login_request(pool, body).await
}
async fn auth_login_request(pool: &DbPool, body: AuthLoginRequest) -> Result<Response, AppError> {
debug!(?body);
let AuthLoginRequest { username, password } = body;
let UserIdAndHashedPassword {
id: _id,
password: hashed_password,
} = get_username_and_password_by_username(pool, username).await?;
verify_password(password, hashed_password)?;
Ok(().into_response()) Ok(().into_response())
} }

View file

@ -1,3 +1,4 @@
mod handler; mod handler;
mod models;
pub use handler::*; pub use handler::*;

View file

@ -0,0 +1,5 @@
mod request;
mod response;
pub use request::*;
pub use response::*;

View file

@ -0,0 +1,20 @@
use std::fmt::Debug;
use serde::Deserialize;
use serde_with::serde_as;
#[serde_as]
#[derive(Deserialize)]
pub struct AuthLoginRequest {
pub username: String,
pub password: String,
}
impl Debug for AuthLoginRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AuthLoginRequest")
.field("username", &self.username)
.field("password", &"********")
.finish()
}
}

View file

@ -0,0 +1,19 @@
use std::time::SystemTime;
use serde::Serialize;
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthLoginResponse {
pub user_id: i32,
pub access: AuthLoginTokenData,
pub auth: AuthLoginTokenData,
}
#[derive(Debug, Serialize)]
pub struct AuthLoginTokenData {
pub token: String,
#[serde(serialize_with = "humantime_serde::serialize")]
pub expiration: SystemTime,
}

View file

@ -12,26 +12,36 @@ use humantime::format_duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use tracing::{error, info, info_span, warn, Span}; use tracing::{error, info, info_span, warn, Span};
use ulid::Ulid; use uuid::Uuid;
use crate::{ use crate::{
db::DbPool, db::DbPool,
models::{AppError, Environment}, models::{AppError, Environment},
services::CachePool,
}; };
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pool: DbPool, db_pool: DbPool,
cache_pool: CachePool,
env: Environment, env: Environment,
} }
impl AppState { impl AppState {
pub fn new(pool: DbPool, env: Environment) -> Self { pub fn new(db_pool: DbPool, cache_pool: CachePool, env: Environment) -> Self {
Self { pool, env } Self {
db_pool,
cache_pool,
env,
}
} }
pub fn pool(&self) -> &DbPool { pub fn db_pool(&self) -> &DbPool {
&self.pool &self.db_pool
}
pub fn cache_pool(&self) -> &CachePool {
&self.cache_pool
} }
pub fn env(&self) -> &Environment { pub fn env(&self) -> &Environment {
@ -39,7 +49,11 @@ impl AppState {
} }
} }
pub async fn start_app(pool: DbPool, env: Environment) -> Result<(), AppError> { pub async fn start_app(
db_pool: DbPool,
cache_pool: CachePool,
env: Environment,
) -> Result<(), AppError> {
let address = env.hostname(); let address = env.hostname();
let port = env.port(); let port = env.port();
@ -55,7 +69,7 @@ pub async fn start_app(pool: DbPool, env: Environment) -> Result<(), AppError> {
.get::<MatchedPath>() .get::<MatchedPath>()
.map(MatchedPath::as_str).unwrap_or(request.uri().path()); .map(MatchedPath::as_str).unwrap_or(request.uri().path());
info_span!("api_request", request_id = %Ulid::new(), method = %request.method(), %path, status = tracing::field::Empty) info_span!("api_request", request_id = %Uuid::now_v7(), method = %request.method(), %path, status = tracing::field::Empty)
}) })
.on_response(|response: &Response, duration: Duration, span: &Span| { .on_response(|response: &Response, duration: Duration, span: &Span| {
let status = response.status(); let status = response.status();
@ -68,7 +82,7 @@ pub async fn start_app(pool: DbPool, env: Environment) -> Result<(), AppError> {
} }
}); });
let state = AppState::new(pool, env); let state = AppState::new(db_pool, cache_pool, env);
let app = Router::new() let app = Router::new()
.merge(user::requests(state.clone())) .merge(user::requests(state.clone()))
.merge(auth::requests(state.clone())) .merge(auth::requests(state.clone()))

View file

@ -1,10 +1,13 @@
use std::sync::mpsc::Sender; use std::{sync::mpsc::Sender, time::SystemTime};
use crate::{ use crate::{
db::{insert_new_user, DbPool, NewUserEntity, UserEntity}, db::{insert_new_user, DbPool, NewUserEntity, UserEntity},
models::{ApiResponse, AppError}, models::{ApiResponse, AppError, Session},
requests::AppState, requests::AppState,
services::{auth_token::generate_new_user_token, hash_string, UserConfirmationMessage}, services::{
self, auth_token::generate_new_user_token, hash_password, CachePool,
UserConfirmationMessage,
},
}; };
use axum::{ use axum::{
debug_handler, debug_handler,
@ -27,7 +30,8 @@ pub async fn user_registration_post_handler(
register_new_user_request( register_new_user_request(
request, request,
state.pool(), state.db_pool(),
state.cache_pool(),
env.token_key(), env.token_key(),
env.send_verification_email(), env.send_verification_email(),
env.email_sender(), env.email_sender(),
@ -37,7 +41,8 @@ pub async fn user_registration_post_handler(
async fn register_new_user_request( async fn register_new_user_request(
body: UserRegistrationRequest, body: UserRegistrationRequest,
pool: &DbPool, db_pool: &DbPool,
cache_pool: &CachePool,
signing_key: &SymmetricKey<V4>, signing_key: &SymmetricKey<V4>,
send_verification_email: bool, send_verification_email: bool,
email_sender: &Sender<UserConfirmationMessage>, email_sender: &Sender<UserConfirmationMessage>,
@ -51,10 +56,10 @@ async fn register_new_user_request(
name, name,
} = body; } = body;
let hashed_password = hash_string(password); let hashed_password = hash_password(password);
let new_user = NewUserEntity { let new_user = NewUserEntity {
username, username: username.clone(),
password: hashed_password.to_string(), password: hashed_password.to_string(),
email, email,
name, name,
@ -65,7 +70,7 @@ async fn register_new_user_request(
name, name,
email, email,
.. ..
} = insert_new_user(pool, new_user).await.map_err(|err| { } = insert_new_user(db_pool, new_user).await.map_err(|err| {
if err.is_duplicate_record() { if err.is_duplicate_record() {
AppError::duplicate_record( AppError::duplicate_record(
"There is already an account associated with this username or email address.", "There is already an account associated with this username or email address.",
@ -75,7 +80,23 @@ async fn register_new_user_request(
} }
})?; })?;
let (verification_token, expiration) = generate_new_user_token(signing_key, user_id); let (verification_token, token_id, expires_at) = generate_new_user_token(signing_key, user_id);
let new_user_session = Session {
user_id,
username,
created_at: SystemTime::now(),
expires_at,
};
let expires_in = expires_at.duration_since(SystemTime::now()).unwrap();
services::user_session::store_user_session(
cache_pool,
token_id,
new_user_session,
Some(expires_in),
)
.await?;
let response_body = if send_verification_email { let response_body = if send_verification_email {
let new_user_confirmation_message = UserConfirmationMessage { let new_user_confirmation_message = UserConfirmationMessage {
@ -92,13 +113,13 @@ async fn register_new_user_request(
UserRegistrationResponse { UserRegistrationResponse {
id: user_id, id: user_id,
expiration, expires_at,
verification_token: None, verification_token: None,
} }
} else { } else {
UserRegistrationResponse { UserRegistrationResponse {
id: user_id, id: user_id,
expiration, expires_at,
verification_token: Some(verification_token), verification_token: Some(verification_token),
} }
}; };

View file

@ -1,5 +1,5 @@
mod registration_request; mod request;
mod registration_response; mod response;
pub use registration_request::*; pub use request::*;
pub use registration_response::*; pub use response::*;

View file

@ -11,7 +11,7 @@ pub struct UserRegistrationResponse {
pub id: i32, pub id: i32,
#[serde(serialize_with = "humantime_serde::serialize")] #[serde(serialize_with = "humantime_serde::serialize")]
pub expiration: SystemTime, pub expires_at: SystemTime,
pub verification_token: Option<String>, pub verification_token: Option<String>,
} }

View file

@ -23,7 +23,7 @@ pub async fn user_verification_get_handler(
Path(user_id): Path<i32>, Path(user_id): Path<i32>,
Query(query): Query<UserVerifyGetParams>, Query(query): Query<UserVerifyGetParams>,
) -> Result<Response, AppError> { ) -> Result<Response, AppError> {
let pool = state.pool(); let pool = state.db_pool();
let env = state.env(); let env = state.env();
let UserVerifyGetParams { verification_token } = query; let UserVerifyGetParams { verification_token } = query;
@ -42,7 +42,6 @@ async fn verify_new_user_request(
let validation_rules = { let validation_rules = {
let mut rules = ClaimsValidationRules::new(); let mut rules = ClaimsValidationRules::new();
rules.validate_audience_with(format!("/user/{user_id}/verify").as_str()); rules.validate_audience_with(format!("/user/{user_id}/verify").as_str());
rules rules
}; };

View file

@ -12,8 +12,8 @@ pub struct UserVerifyGetResponse {
impl UserVerifyGetResponse { impl UserVerifyGetResponse {
pub fn new(key: &SymmetricKey<V4>, user_id: i32) -> Self { pub fn new(key: &SymmetricKey<V4>, user_id: i32) -> Self {
let (access_token, access_token_expiration) = generate_access_token(key, user_id); let (access_token, _, access_token_expiration) = generate_access_token(key, user_id);
let (auth_token, auth_token_expiration) = generate_auth_token(key, user_id); let (auth_token, _, auth_token_expiration) = generate_auth_token(key, user_id);
Self { Self {
access: UserVerifyGetResponseTokenAndExpiration { access: UserVerifyGetResponseTokenAndExpiration {

View file

@ -13,7 +13,7 @@ use uuid::Uuid;
use crate::models::AppError; use crate::models::AppError;
static FOURTY_FIVE_DAYS: Duration = Duration::from_secs(3_888_000); static ONE_DAY: Duration = Duration::from_secs(86_400);
static ONE_HOUR: Duration = Duration::from_secs(3_600); static ONE_HOUR: Duration = Duration::from_secs(3_600);
static FIFTEEN_MINUTES: Duration = Duration::from_secs(900); static FIFTEEN_MINUTES: Duration = Duration::from_secs(900);
@ -51,19 +51,19 @@ pub fn verify_token(
Ok(()) Ok(())
} }
pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, SystemTime) { pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {
generate_token(key, user_id, Some(FOURTY_FIVE_DAYS), None) generate_token(key, user_id, ONE_HOUR, None)
} }
pub fn generate_auth_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, SystemTime) { pub fn generate_auth_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {
generate_token(key, user_id, None, None) generate_token(key, user_id, ONE_DAY, None)
} }
pub fn generate_new_user_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, SystemTime) { pub fn generate_new_user_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {
generate_token( generate_token(
key, key,
user_id, user_id,
Some(FIFTEEN_MINUTES), FIFTEEN_MINUTES,
Some(format!("/user/{user_id}/verify").as_str()), Some(format!("/user/{user_id}/verify").as_str()),
) )
} }
@ -71,20 +71,16 @@ pub fn generate_new_user_token(key: &SymmetricKey<V4>, user_id: i32) -> (String,
fn generate_token( fn generate_token(
key: &SymmetricKey<V4>, key: &SymmetricKey<V4>,
user_id: i32, user_id: i32,
duration: Option<Duration>, expires_in: Duration,
audience: Option<&str>, audience: Option<&str>,
) -> (String, SystemTime) { ) -> (String, Uuid, SystemTime) {
let now = SystemTime::now(); let now = SystemTime::now();
let expiration = if let Some(duration) = duration { let token_id = Uuid::now_v7();
duration
} else {
ONE_HOUR
};
let token = Claims::new_expires_in(&expiration) let token = Claims::new_expires_in(&expires_in)
.and_then(|mut claims| { .and_then(|mut claims| {
claims claims
.token_identifier(Uuid::now_v7().to_string().as_str()) .token_identifier(token_id.to_string().as_str())
.map(|_| claims) .map(|_| claims)
}) })
.and_then(|mut claims| { .and_then(|mut claims| {
@ -116,7 +112,7 @@ fn generate_token(
}) })
.unwrap(); .unwrap();
(token, now + expiration) (token, token_id, now + expires_in)
} }
#[cfg(test)] #[cfg(test)]
@ -136,7 +132,7 @@ mod tests {
.and_then(|bytes| SymmetricKey::<V4>::from(bytes.as_slice()).map_err(|_| ())) .and_then(|bytes| SymmetricKey::<V4>::from(bytes.as_slice()).map_err(|_| ()))
.unwrap(); .unwrap();
let token = generate_token(&key, 1, Some(Duration::from_secs(60)), Some("testing")).0; let token = generate_token(&key, 1, Duration::from_secs(60), Some("testing")).0;
let footer = { let footer = {
let mut footer = Footer::new(); let mut footer = Footer::new();

64
api/src/services/cache.rs Normal file
View file

@ -0,0 +1,64 @@
use std::{fmt::Debug, time::Duration};
use bb8_redis::{
bb8::{Pool, PooledConnection},
RedisConnectionManager,
};
use redis::{AsyncCommands, IntoConnectionInfo, ToRedisArgs};
use crate::models::AppError;
pub type CachePool = Pool<RedisConnectionManager>;
pub async fn create_cache_connection_pool(
connection_info: impl IntoConnectionInfo + Clone + Debug,
) -> Result<CachePool, AppError> {
let manager = RedisConnectionManager::new(connection_info)
.map_err(|_| AppError::connection_info("cache"))?;
let pool = Pool::builder().build(manager).await.unwrap();
Ok(pool)
}
pub async fn store_object<
'a,
F: ToRedisArgs + Send + Sync + 'a,
V: ToRedisArgs + Send + Sync + 'a,
O: AsRef<[(F, V)]>,
>(
cache_pool: &CachePool,
key: &str,
object: O,
) -> Result<(), AppError> {
let mut conn = get_connection_from_pool(cache_pool).await?;
let _: () = conn.hset_multiple(key, object.as_ref()).await?;
Ok(())
}
pub async fn store_object_with_expiration<
'a,
F: ToRedisArgs + Send + Sync + 'a,
V: ToRedisArgs + Send + Sync + 'a,
O: AsRef<[(F, V)]>,
>(
cache_pool: &CachePool,
key: &str,
object: O,
expiration: Duration,
) -> Result<(), AppError> {
let mut conn = get_connection_from_pool(cache_pool).await?;
let _: () = conn.hset_multiple(key, object.as_ref()).await?;
let _: () = conn
.expire(key, expiration.as_secs().try_into().unwrap())
.await?;
Ok(())
}
async fn get_connection_from_pool(
cache_pool: &CachePool,
) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> {
cache_pool.get().await.map_err(From::from)
}

View file

@ -1,14 +0,0 @@
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHashString, SaltString},
Argon2,
};
pub fn hash_string(string: String) -> PasswordHashString {
let algorithm = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let hashed_password =
PasswordHash::generate(algorithm, string.as_bytes(), salt.as_salt()).unwrap();
hashed_password.serialize()
}

View file

@ -1,8 +1,12 @@
pub mod auth_token; pub mod auth_token;
mod hasher; mod cache;
mod logger; mod logger;
mod mailer; mod mailer;
mod password_hasher;
pub mod user_session;
pub use hasher::*;
pub use logger::*; pub use logger::*;
pub use mailer::*; pub use mailer::*;
pub use password_hasher::*;
pub use cache::{create_cache_connection_pool, CachePool};

View file

@ -0,0 +1,25 @@
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHashString, SaltString},
Argon2, PasswordVerifier,
};
use crate::models::AppError;
pub fn hash_password(password: String) -> PasswordHashString {
let algorithm = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let hashed_password =
PasswordHash::generate(algorithm, password.as_bytes(), salt.as_salt()).unwrap();
hashed_password.serialize()
}
pub fn verify_password(password: String, hashed_password: String) -> Result<(), AppError> {
let algorithm = Argon2::default();
let hash = PasswordHash::new(hashed_password.as_str()).unwrap();
algorithm
.verify_password(password.as_bytes(), &hash)
.map_err(|_| AppError::invalid_password())
}

View file

@ -0,0 +1,38 @@
use std::time::Duration;
use uuid::Uuid;
use crate::{
models::{AppError, Session},
services::cache::CachePool,
};
use super::cache;
static USER_SESSION_CACHE_KEY_PREFIX: &'static str = "debt_pirate:session:";
pub async fn store_user_session(
cache_pool: &CachePool,
token_id: Uuid,
session: Session,
expiration: Option<Duration>,
) -> Result<(), AppError> {
let hashed_token = blake3::hash(token_id.as_bytes());
let key = format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}");
let cacheable_object = session.to_cacheable_object();
match expiration {
Some(expiration) => {
cache::store_object_with_expiration(
cache_pool,
key.as_str(),
cacheable_object,
expiration,
)
.await?
}
None => cache::store_object(cache_pool, key.as_str(), cacheable_object).await?,
}
Ok(())
}