diff --git a/api/src/models/error.rs b/api/src/models/error.rs index 6002ae1..5ef7ca7 100644 --- a/api/src/models/error.rs +++ b/api/src/models/error.rs @@ -5,7 +5,7 @@ use bb8_redis::bb8::RunError; use http::StatusCode; use redis::RedisError; use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError}; -use tracing::trace; +use tracing::{field, trace}; use super::ApiResponse; @@ -45,6 +45,10 @@ impl AppError { Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars)) } + pub fn missing_session_field(field: &'static str) -> Self { + Self::new(ErrorKind::MissingSessionField(field)) + } + pub fn token_key() -> Self { Self::new(ErrorKind::TokenKey) } @@ -120,6 +124,10 @@ impl Display for AppError { "Missing required environment variables: {}", missing_vars.join(", ") ), + ErrorKind::MissingSessionField(field) => write!( + f, + "Cannot retrieve session: missing required field: {field}" + ), ErrorKind::Sqlx(err) => write!(f, "{err}"), ErrorKind::TokenKey => write!( f, @@ -140,6 +148,7 @@ enum ErrorKind { InvalidPassword, InvalidToken, MissingEnvironmentVariables(Vec<&'static str>), + MissingSessionField(&'static str), Sqlx(SqlxError), TokenKey, } diff --git a/api/src/models/session.rs b/api/src/models/session.rs index db171c5..1f6bb1c 100644 --- a/api/src/models/session.rs +++ b/api/src/models/session.rs @@ -1,4 +1,4 @@ -use std::time::SystemTime; +use std::time::{SystemTime, UNIX_EPOCH}; use humantime::format_rfc3339; use redis::ToRedisArgs; @@ -30,3 +30,14 @@ impl Session { ] } } + +impl Default for Session { + fn default() -> Self { + Self { + user_id: Default::default(), + username: Default::default(), + created_at: UNIX_EPOCH, + expires_at: UNIX_EPOCH, + } + } +} diff --git a/api/src/requests/user/verify/handler.rs b/api/src/requests/user/verify/handler.rs index b7faa92..52761a9 100644 --- a/api/src/requests/user/verify/handler.rs +++ b/api/src/requests/user/verify/handler.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use axum::{ debug_handler, extract::{Path, Query, State}, @@ -6,13 +8,14 @@ use axum::{ }; use http::StatusCode; use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4}; -use tracing::{debug, error}; +use tracing::{debug, error, trace}; +use uuid::Uuid; use crate::{ db::{verify_user, DbPool}, models::{ApiResponse, AppError}, requests::AppState, - services::auth_token::verify_token, + services::{auth_token::verify_token, user_session, CachePool}, }; use super::{UserVerifyGetParams, UserVerifyGetResponse}; @@ -23,16 +26,18 @@ pub async fn user_verification_get_handler( Path(user_id): Path, Query(query): Query, ) -> Result { - let pool = state.db_pool(); + let db_pool = state.db_pool(); + let cache_pool = state.cache_pool(); let env = state.env(); let UserVerifyGetParams { verification_token } = query; let token_key = env.token_key(); - verify_new_user_request(pool, user_id, verification_token, token_key).await + verify_new_user_request(db_pool, cache_pool, user_id, verification_token, token_key).await } async fn verify_new_user_request( - pool: &DbPool, + db_pool: &DbPool, + cache_pool: &CachePool, user_id: i32, verification_token: String, token_key: &SymmetricKey, @@ -45,15 +50,30 @@ async fn verify_new_user_request( rules }; + let verified_token = verify_token( + token_key, + verification_token.as_str(), + Some(validation_rules.clone()), + ) + .inspect_err(|err| error!(?err))?; + + let token_id = verified_token + .payload_claims() + .map(|claims| claims.get_claim("jti")) + .flatten() + .map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap()) + .unwrap(); + + let _ = user_session::get_user_session(cache_pool, token_id).await?; + let response = verify_token( token_key, verification_token.as_str(), Some(validation_rules), ) - .map(|_| UserVerifyGetResponse::new(token_key, user_id)) - .inspect_err(|err| error!(?err))?; + .map(|_| UserVerifyGetResponse::new(token_key, user_id))?; - verify_user(pool, user_id) + verify_user(db_pool, user_id) .await .inspect_err(|err| error!(?err))?; diff --git a/api/src/services/auth_token.rs b/api/src/services/auth_token.rs index bc0656c..ac9e82b 100644 --- a/api/src/services/auth_token.rs +++ b/api/src/services/auth_token.rs @@ -5,7 +5,7 @@ use pasetors::{ footer::Footer, keys::SymmetricKey, local, - token::UntrustedToken, + token::{TrustedToken, UntrustedToken}, version4::V4, }; use tracing::error; @@ -21,7 +21,7 @@ pub fn verify_token( key: &SymmetricKey, token: &str, validation_rules: Option, -) -> Result<(), AppError> { +) -> Result { let token = UntrustedToken::try_from(token) .inspect_err(|err| error!(?err)) .map_err(|_| AppError::invalid_token())?; @@ -38,7 +38,7 @@ pub fn verify_token( footer }; - let _ = local::decrypt( + let token = local::decrypt( key, &token, &validation_rules, @@ -48,7 +48,7 @@ pub fn verify_token( .inspect_err(|err| error!(?err)) .map_err(|_| AppError::invalid_token())?; - Ok(()) + Ok(token) } pub fn generate_access_token(key: &SymmetricKey, user_id: i32) -> (String, Uuid, SystemTime) { diff --git a/api/src/services/cache.rs b/api/src/services/cache.rs index 4d40f6f..81633e1 100644 --- a/api/src/services/cache.rs +++ b/api/src/services/cache.rs @@ -1,10 +1,10 @@ -use std::{fmt::Debug, time::Duration}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration}; use bb8_redis::{ bb8::{Pool, PooledConnection}, RedisConnectionManager, }; -use redis::{AsyncCommands, IntoConnectionInfo, ToRedisArgs}; +use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo, ToRedisArgs, Value}; use crate::models::AppError; @@ -57,6 +57,16 @@ pub async fn store_object_with_expiration< Ok(()) } +pub async fn get_object( + cache_pool: &CachePool, + key: &str, +) -> Result>, AppError> { + let mut conn = get_connection_from_pool(cache_pool).await?; + let hash_fields: Option> = conn.hgetall(key).await?; + + Ok(hash_fields) +} + async fn get_connection_from_pool( cache_pool: &CachePool, ) -> Result, AppError> { diff --git a/api/src/services/user_session.rs b/api/src/services/user_session.rs index 9152de9..84a4738 100644 --- a/api/src/services/user_session.rs +++ b/api/src/services/user_session.rs @@ -1,5 +1,7 @@ -use std::time::Duration; +use std::time::{Duration, SystemTime}; +use humantime::parse_rfc3339; +use redis::FromRedisValue; use uuid::Uuid; use crate::{ @@ -17,8 +19,7 @@ pub async fn store_user_session( session: Session, expiration: Option, ) -> Result<(), AppError> { - let hashed_token = blake3::hash(token_id.as_bytes()); - let key = format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}"); + let key = make_key(token_id); let cacheable_object = session.to_cacheable_object(); match expiration { @@ -36,3 +37,61 @@ pub async fn store_user_session( Ok(()) } + +pub async fn get_user_session( + cache_pool: &CachePool, + token_id: Uuid, +) -> Result, AppError> { + let key = make_key(token_id); + let object = cache::get_object::(cache_pool, key.as_str()).await?; + + if let Some(object) = object { + let user_id: i32 = object + .get("userId") + .ok_or(AppError::missing_session_field("userId")) + .and_then(|value| { + FromRedisValue::from_redis_value(value) + .map_err(|_| AppError::missing_session_field("userId")) + })?; + + let username: String = object + .get("username") + .ok_or(AppError::missing_session_field("username")) + .and_then(|value| { + FromRedisValue::from_redis_value(value) + .map_err(|_| AppError::missing_session_field("username")) + })?; + + let created_at: SystemTime = object + .get("createdAt") + .ok_or(AppError::missing_session_field("createdAt")) + .and_then(|value| { + FromRedisValue::from_redis_value(value) + .map_err(|_| AppError::missing_session_field("createdAt")) + }) + .map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?; + + let expires_at: SystemTime = object + .get("createdAt") + .ok_or(AppError::missing_session_field("expiresAt")) + .and_then(|value| { + FromRedisValue::from_redis_value(value) + .map_err(|_| AppError::missing_session_field("expiresAt")) + }) + .map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?; + + Ok(Some(Session { + user_id, + username, + created_at, + expires_at, + })) + } else { + Ok(None) + } +} + +fn make_key(token_id: Uuid) -> String { + let hashed_token = blake3::hash(token_id.as_bytes()); + format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}") +}