Fetch the temporary session during validation

This commit is contained in:
Z. Charles Dziura 2024-10-05 10:45:41 -04:00
parent d41a92d07a
commit 24a7fdd4ef
6 changed files with 128 additions and 19 deletions

View file

@ -5,7 +5,7 @@ use bb8_redis::bb8::RunError;
use http::StatusCode;
use redis::RedisError;
use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError};
use tracing::trace;
use tracing::{field, trace};
use super::ApiResponse;
@ -45,6 +45,10 @@ impl AppError {
Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars))
}
pub fn missing_session_field(field: &'static str) -> Self {
Self::new(ErrorKind::MissingSessionField(field))
}
pub fn token_key() -> Self {
Self::new(ErrorKind::TokenKey)
}
@ -120,6 +124,10 @@ impl Display for AppError {
"Missing required environment variables: {}",
missing_vars.join(", ")
),
ErrorKind::MissingSessionField(field) => write!(
f,
"Cannot retrieve session: missing required field: {field}"
),
ErrorKind::Sqlx(err) => write!(f, "{err}"),
ErrorKind::TokenKey => write!(
f,
@ -140,6 +148,7 @@ enum ErrorKind {
InvalidPassword,
InvalidToken,
MissingEnvironmentVariables(Vec<&'static str>),
MissingSessionField(&'static str),
Sqlx(SqlxError),
TokenKey,
}

View file

@ -1,4 +1,4 @@
use std::time::SystemTime;
use std::time::{SystemTime, UNIX_EPOCH};
use humantime::format_rfc3339;
use redis::ToRedisArgs;
@ -30,3 +30,14 @@ impl Session {
]
}
}
impl Default for Session {
fn default() -> Self {
Self {
user_id: Default::default(),
username: Default::default(),
created_at: UNIX_EPOCH,
expires_at: UNIX_EPOCH,
}
}
}

View file

@ -1,3 +1,5 @@
use std::str::FromStr;
use axum::{
debug_handler,
extract::{Path, Query, State},
@ -6,13 +8,14 @@ use axum::{
};
use http::StatusCode;
use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4};
use tracing::{debug, error};
use tracing::{debug, error, trace};
use uuid::Uuid;
use crate::{
db::{verify_user, DbPool},
models::{ApiResponse, AppError},
requests::AppState,
services::auth_token::verify_token,
services::{auth_token::verify_token, user_session, CachePool},
};
use super::{UserVerifyGetParams, UserVerifyGetResponse};
@ -23,16 +26,18 @@ pub async fn user_verification_get_handler(
Path(user_id): Path<i32>,
Query(query): Query<UserVerifyGetParams>,
) -> Result<Response, AppError> {
let pool = state.db_pool();
let db_pool = state.db_pool();
let cache_pool = state.cache_pool();
let env = state.env();
let UserVerifyGetParams { verification_token } = query;
let token_key = env.token_key();
verify_new_user_request(pool, user_id, verification_token, token_key).await
verify_new_user_request(db_pool, cache_pool, user_id, verification_token, token_key).await
}
async fn verify_new_user_request(
pool: &DbPool,
db_pool: &DbPool,
cache_pool: &CachePool,
user_id: i32,
verification_token: String,
token_key: &SymmetricKey<V4>,
@ -45,15 +50,30 @@ async fn verify_new_user_request(
rules
};
let verified_token = verify_token(
token_key,
verification_token.as_str(),
Some(validation_rules.clone()),
)
.inspect_err(|err| error!(?err))?;
let token_id = verified_token
.payload_claims()
.map(|claims| claims.get_claim("jti"))
.flatten()
.map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap())
.unwrap();
let _ = user_session::get_user_session(cache_pool, token_id).await?;
let response = verify_token(
token_key,
verification_token.as_str(),
Some(validation_rules),
)
.map(|_| UserVerifyGetResponse::new(token_key, user_id))
.inspect_err(|err| error!(?err))?;
.map(|_| UserVerifyGetResponse::new(token_key, user_id))?;
verify_user(pool, user_id)
verify_user(db_pool, user_id)
.await
.inspect_err(|err| error!(?err))?;

View file

@ -5,7 +5,7 @@ use pasetors::{
footer::Footer,
keys::SymmetricKey,
local,
token::UntrustedToken,
token::{TrustedToken, UntrustedToken},
version4::V4,
};
use tracing::error;
@ -21,7 +21,7 @@ pub fn verify_token(
key: &SymmetricKey<V4>,
token: &str,
validation_rules: Option<ClaimsValidationRules>,
) -> Result<(), AppError> {
) -> Result<TrustedToken, AppError> {
let token = UntrustedToken::try_from(token)
.inspect_err(|err| error!(?err))
.map_err(|_| AppError::invalid_token())?;
@ -38,7 +38,7 @@ pub fn verify_token(
footer
};
let _ = local::decrypt(
let token = local::decrypt(
key,
&token,
&validation_rules,
@ -48,7 +48,7 @@ pub fn verify_token(
.inspect_err(|err| error!(?err))
.map_err(|_| AppError::invalid_token())?;
Ok(())
Ok(token)
}
pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {

View file

@ -1,10 +1,10 @@
use std::{fmt::Debug, time::Duration};
use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration};
use bb8_redis::{
bb8::{Pool, PooledConnection},
RedisConnectionManager,
};
use redis::{AsyncCommands, IntoConnectionInfo, ToRedisArgs};
use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo, ToRedisArgs, Value};
use crate::models::AppError;
@ -57,6 +57,16 @@ pub async fn store_object_with_expiration<
Ok(())
}
pub async fn get_object<K: FromRedisValue + Eq + Hash>(
cache_pool: &CachePool,
key: &str,
) -> Result<Option<HashMap<K, Value>>, AppError> {
let mut conn = get_connection_from_pool(cache_pool).await?;
let hash_fields: Option<HashMap<_, Value>> = conn.hgetall(key).await?;
Ok(hash_fields)
}
async fn get_connection_from_pool(
cache_pool: &CachePool,
) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> {

View file

@ -1,5 +1,7 @@
use std::time::Duration;
use std::time::{Duration, SystemTime};
use humantime::parse_rfc3339;
use redis::FromRedisValue;
use uuid::Uuid;
use crate::{
@ -17,8 +19,7 @@ pub async fn store_user_session(
session: Session,
expiration: Option<Duration>,
) -> Result<(), AppError> {
let hashed_token = blake3::hash(token_id.as_bytes());
let key = format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}");
let key = make_key(token_id);
let cacheable_object = session.to_cacheable_object();
match expiration {
@ -36,3 +37,61 @@ pub async fn store_user_session(
Ok(())
}
pub async fn get_user_session(
cache_pool: &CachePool,
token_id: Uuid,
) -> Result<Option<Session>, AppError> {
let key = make_key(token_id);
let object = cache::get_object::<String>(cache_pool, key.as_str()).await?;
if let Some(object) = object {
let user_id: i32 = object
.get("userId")
.ok_or(AppError::missing_session_field("userId"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("userId"))
})?;
let username: String = object
.get("username")
.ok_or(AppError::missing_session_field("username"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("username"))
})?;
let created_at: SystemTime = object
.get("createdAt")
.ok_or(AppError::missing_session_field("createdAt"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("createdAt"))
})
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
let expires_at: SystemTime = object
.get("createdAt")
.ok_or(AppError::missing_session_field("expiresAt"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("expiresAt"))
})
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
Ok(Some(Session {
user_id,
username,
created_at,
expires_at,
}))
} else {
Ok(None)
}
}
fn make_key(token_id: Uuid) -> String {
let hashed_token = blake3::hash(token_id.as_bytes());
format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}")
}