Fetch the temporary session during validation

This commit is contained in:
Z. Charles Dziura 2024-10-05 10:45:41 -04:00
parent d41a92d07a
commit 24a7fdd4ef
6 changed files with 128 additions and 19 deletions

View file

@ -5,7 +5,7 @@ use bb8_redis::bb8::RunError;
use http::StatusCode; use http::StatusCode;
use redis::RedisError; use redis::RedisError;
use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError}; use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError};
use tracing::trace; use tracing::{field, trace};
use super::ApiResponse; use super::ApiResponse;
@ -45,6 +45,10 @@ impl AppError {
Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars)) Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars))
} }
pub fn missing_session_field(field: &'static str) -> Self {
Self::new(ErrorKind::MissingSessionField(field))
}
pub fn token_key() -> Self { pub fn token_key() -> Self {
Self::new(ErrorKind::TokenKey) Self::new(ErrorKind::TokenKey)
} }
@ -120,6 +124,10 @@ impl Display for AppError {
"Missing required environment variables: {}", "Missing required environment variables: {}",
missing_vars.join(", ") missing_vars.join(", ")
), ),
ErrorKind::MissingSessionField(field) => write!(
f,
"Cannot retrieve session: missing required field: {field}"
),
ErrorKind::Sqlx(err) => write!(f, "{err}"), ErrorKind::Sqlx(err) => write!(f, "{err}"),
ErrorKind::TokenKey => write!( ErrorKind::TokenKey => write!(
f, f,
@ -140,6 +148,7 @@ enum ErrorKind {
InvalidPassword, InvalidPassword,
InvalidToken, InvalidToken,
MissingEnvironmentVariables(Vec<&'static str>), MissingEnvironmentVariables(Vec<&'static str>),
MissingSessionField(&'static str),
Sqlx(SqlxError), Sqlx(SqlxError),
TokenKey, TokenKey,
} }

View file

@ -1,4 +1,4 @@
use std::time::SystemTime; use std::time::{SystemTime, UNIX_EPOCH};
use humantime::format_rfc3339; use humantime::format_rfc3339;
use redis::ToRedisArgs; use redis::ToRedisArgs;
@ -30,3 +30,14 @@ impl Session {
] ]
} }
} }
impl Default for Session {
fn default() -> Self {
Self {
user_id: Default::default(),
username: Default::default(),
created_at: UNIX_EPOCH,
expires_at: UNIX_EPOCH,
}
}
}

View file

@ -1,3 +1,5 @@
use std::str::FromStr;
use axum::{ use axum::{
debug_handler, debug_handler,
extract::{Path, Query, State}, extract::{Path, Query, State},
@ -6,13 +8,14 @@ use axum::{
}; };
use http::StatusCode; use http::StatusCode;
use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4}; use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4};
use tracing::{debug, error}; use tracing::{debug, error, trace};
use uuid::Uuid;
use crate::{ use crate::{
db::{verify_user, DbPool}, db::{verify_user, DbPool},
models::{ApiResponse, AppError}, models::{ApiResponse, AppError},
requests::AppState, requests::AppState,
services::auth_token::verify_token, services::{auth_token::verify_token, user_session, CachePool},
}; };
use super::{UserVerifyGetParams, UserVerifyGetResponse}; use super::{UserVerifyGetParams, UserVerifyGetResponse};
@ -23,16 +26,18 @@ pub async fn user_verification_get_handler(
Path(user_id): Path<i32>, Path(user_id): Path<i32>,
Query(query): Query<UserVerifyGetParams>, Query(query): Query<UserVerifyGetParams>,
) -> Result<Response, AppError> { ) -> Result<Response, AppError> {
let pool = state.db_pool(); let db_pool = state.db_pool();
let cache_pool = state.cache_pool();
let env = state.env(); let env = state.env();
let UserVerifyGetParams { verification_token } = query; let UserVerifyGetParams { verification_token } = query;
let token_key = env.token_key(); let token_key = env.token_key();
verify_new_user_request(pool, user_id, verification_token, token_key).await verify_new_user_request(db_pool, cache_pool, user_id, verification_token, token_key).await
} }
async fn verify_new_user_request( async fn verify_new_user_request(
pool: &DbPool, db_pool: &DbPool,
cache_pool: &CachePool,
user_id: i32, user_id: i32,
verification_token: String, verification_token: String,
token_key: &SymmetricKey<V4>, token_key: &SymmetricKey<V4>,
@ -45,15 +50,30 @@ async fn verify_new_user_request(
rules rules
}; };
let verified_token = verify_token(
token_key,
verification_token.as_str(),
Some(validation_rules.clone()),
)
.inspect_err(|err| error!(?err))?;
let token_id = verified_token
.payload_claims()
.map(|claims| claims.get_claim("jti"))
.flatten()
.map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap())
.unwrap();
let _ = user_session::get_user_session(cache_pool, token_id).await?;
let response = verify_token( let response = verify_token(
token_key, token_key,
verification_token.as_str(), verification_token.as_str(),
Some(validation_rules), Some(validation_rules),
) )
.map(|_| UserVerifyGetResponse::new(token_key, user_id)) .map(|_| UserVerifyGetResponse::new(token_key, user_id))?;
.inspect_err(|err| error!(?err))?;
verify_user(pool, user_id) verify_user(db_pool, user_id)
.await .await
.inspect_err(|err| error!(?err))?; .inspect_err(|err| error!(?err))?;

View file

@ -5,7 +5,7 @@ use pasetors::{
footer::Footer, footer::Footer,
keys::SymmetricKey, keys::SymmetricKey,
local, local,
token::UntrustedToken, token::{TrustedToken, UntrustedToken},
version4::V4, version4::V4,
}; };
use tracing::error; use tracing::error;
@ -21,7 +21,7 @@ pub fn verify_token(
key: &SymmetricKey<V4>, key: &SymmetricKey<V4>,
token: &str, token: &str,
validation_rules: Option<ClaimsValidationRules>, validation_rules: Option<ClaimsValidationRules>,
) -> Result<(), AppError> { ) -> Result<TrustedToken, AppError> {
let token = UntrustedToken::try_from(token) let token = UntrustedToken::try_from(token)
.inspect_err(|err| error!(?err)) .inspect_err(|err| error!(?err))
.map_err(|_| AppError::invalid_token())?; .map_err(|_| AppError::invalid_token())?;
@ -38,7 +38,7 @@ pub fn verify_token(
footer footer
}; };
let _ = local::decrypt( let token = local::decrypt(
key, key,
&token, &token,
&validation_rules, &validation_rules,
@ -48,7 +48,7 @@ pub fn verify_token(
.inspect_err(|err| error!(?err)) .inspect_err(|err| error!(?err))
.map_err(|_| AppError::invalid_token())?; .map_err(|_| AppError::invalid_token())?;
Ok(()) Ok(token)
} }
pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) { pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {

View file

@ -1,10 +1,10 @@
use std::{fmt::Debug, time::Duration}; use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration};
use bb8_redis::{ use bb8_redis::{
bb8::{Pool, PooledConnection}, bb8::{Pool, PooledConnection},
RedisConnectionManager, RedisConnectionManager,
}; };
use redis::{AsyncCommands, IntoConnectionInfo, ToRedisArgs}; use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo, ToRedisArgs, Value};
use crate::models::AppError; use crate::models::AppError;
@ -57,6 +57,16 @@ pub async fn store_object_with_expiration<
Ok(()) Ok(())
} }
pub async fn get_object<K: FromRedisValue + Eq + Hash>(
cache_pool: &CachePool,
key: &str,
) -> Result<Option<HashMap<K, Value>>, AppError> {
let mut conn = get_connection_from_pool(cache_pool).await?;
let hash_fields: Option<HashMap<_, Value>> = conn.hgetall(key).await?;
Ok(hash_fields)
}
async fn get_connection_from_pool( async fn get_connection_from_pool(
cache_pool: &CachePool, cache_pool: &CachePool,
) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> { ) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> {

View file

@ -1,5 +1,7 @@
use std::time::Duration; use std::time::{Duration, SystemTime};
use humantime::parse_rfc3339;
use redis::FromRedisValue;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
@ -17,8 +19,7 @@ pub async fn store_user_session(
session: Session, session: Session,
expiration: Option<Duration>, expiration: Option<Duration>,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let hashed_token = blake3::hash(token_id.as_bytes()); let key = make_key(token_id);
let key = format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}");
let cacheable_object = session.to_cacheable_object(); let cacheable_object = session.to_cacheable_object();
match expiration { match expiration {
@ -36,3 +37,61 @@ pub async fn store_user_session(
Ok(()) Ok(())
} }
pub async fn get_user_session(
cache_pool: &CachePool,
token_id: Uuid,
) -> Result<Option<Session>, AppError> {
let key = make_key(token_id);
let object = cache::get_object::<String>(cache_pool, key.as_str()).await?;
if let Some(object) = object {
let user_id: i32 = object
.get("userId")
.ok_or(AppError::missing_session_field("userId"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("userId"))
})?;
let username: String = object
.get("username")
.ok_or(AppError::missing_session_field("username"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("username"))
})?;
let created_at: SystemTime = object
.get("createdAt")
.ok_or(AppError::missing_session_field("createdAt"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("createdAt"))
})
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
let expires_at: SystemTime = object
.get("createdAt")
.ok_or(AppError::missing_session_field("expiresAt"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("expiresAt"))
})
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
Ok(Some(Session {
user_id,
username,
created_at,
expires_at,
}))
} else {
Ok(None)
}
}
fn make_key(token_id: Uuid) -> String {
let hashed_token = blake3::hash(token_id.as_bytes());
format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}")
}