Change status code associated with missing session

This commit is contained in:
Z. Charles Dziura 2024-10-05 13:20:53 -04:00
parent 24a7fdd4ef
commit 3768cd26e6
6 changed files with 84 additions and 46 deletions

View file

@ -5,7 +5,7 @@ use bb8_redis::bb8::RunError;
use http::StatusCode; use http::StatusCode;
use redis::RedisError; use redis::RedisError;
use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError}; use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError};
use tracing::{field, trace}; use tracing::trace;
use super::ApiResponse; use super::ApiResponse;
@ -45,10 +45,14 @@ impl AppError {
Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars)) Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars))
} }
pub fn missing_session_field(field: &'static str) -> Self { pub fn _missing_session_field(field: &'static str) -> Self {
Self::new(ErrorKind::MissingSessionField(field)) Self::new(ErrorKind::MissingSessionField(field))
} }
pub fn no_session_found() -> Self {
Self::new(ErrorKind::NoSessionFound)
}
pub fn token_key() -> Self { pub fn token_key() -> Self {
Self::new(ErrorKind::TokenKey) Self::new(ErrorKind::TokenKey)
} }
@ -128,6 +132,7 @@ impl Display for AppError {
f, f,
"Cannot retrieve session: missing required field: {field}" "Cannot retrieve session: missing required field: {field}"
), ),
ErrorKind::NoSessionFound => write!(f, "No session found"),
ErrorKind::Sqlx(err) => write!(f, "{err}"), ErrorKind::Sqlx(err) => write!(f, "{err}"),
ErrorKind::TokenKey => write!( ErrorKind::TokenKey => write!(
f, f,
@ -137,6 +142,7 @@ impl Display for AppError {
} }
} }
#[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
enum ErrorKind { enum ErrorKind {
AppStartupError(io::Error), AppStartupError(io::Error),
@ -149,6 +155,7 @@ enum ErrorKind {
InvalidToken, InvalidToken,
MissingEnvironmentVariables(Vec<&'static str>), MissingEnvironmentVariables(Vec<&'static str>),
MissingSessionField(&'static str), MissingSessionField(&'static str),
NoSessionFound,
Sqlx(SqlxError), Sqlx(SqlxError),
TokenKey, TokenKey,
} }
@ -164,6 +171,10 @@ impl IntoResponse for AppError {
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
ApiResponse::new_with_error(self).into_json_response(), ApiResponse::new_with_error(self).into_json_response(),
), ),
&ErrorKind::NoSessionFound => (
StatusCode::UNAUTHORIZED,
ApiResponse::new_with_error(self).into_json_response(),
),
_ => ( _ => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
ApiResponse::new_with_error(self).into_json_response(), ApiResponse::new_with_error(self).into_json_response(),

View file

@ -2,4 +2,4 @@ mod request;
mod response; mod response;
pub use request::*; pub use request::*;
pub use response::*; // pub use response::*;

View file

@ -8,7 +8,7 @@ use axum::{
}; };
use http::StatusCode; use http::StatusCode;
use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4}; use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4};
use tracing::{debug, error, trace}; use tracing::{debug, error};
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
@ -64,7 +64,15 @@ async fn verify_new_user_request(
.map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap()) .map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap())
.unwrap(); .unwrap();
let _ = user_session::get_user_session(cache_pool, token_id).await?; user_session::exists(cache_pool, token_id)
.await
.and_then(|exists| {
if exists {
Ok(())
} else {
Err(AppError::no_session_found())
}
})?;
let response = verify_token( let response = verify_token(
token_key, token_key,

View file

@ -46,7 +46,10 @@ pub fn verify_token(
Some("TODO_ENV_NAME_HERE".as_bytes()), Some("TODO_ENV_NAME_HERE".as_bytes()),
) )
.inspect_err(|err| error!(?err)) .inspect_err(|err| error!(?err))
.map_err(|_| AppError::invalid_token())?; .map_err(|err| {
error!(?err);
AppError::invalid_token()
})?;
Ok(token) Ok(token)
} }

View file

@ -57,7 +57,7 @@ pub async fn store_object_with_expiration<
Ok(()) Ok(())
} }
pub async fn get_object<K: FromRedisValue + Eq + Hash>( pub async fn _get_object<K: FromRedisValue + Eq + Hash>(
cache_pool: &CachePool, cache_pool: &CachePool,
key: &str, key: &str,
) -> Result<Option<HashMap<K, Value>>, AppError> { ) -> Result<Option<HashMap<K, Value>>, AppError> {
@ -67,6 +67,11 @@ pub async fn get_object<K: FromRedisValue + Eq + Hash>(
Ok(hash_fields) Ok(hash_fields)
} }
pub async fn exists(cache_pool: &CachePool, key: &str) -> Result<bool, AppError> {
let mut conn = get_connection_from_pool(cache_pool).await?;
conn.exists(key).await.map_err(Into::into)
}
async fn get_connection_from_pool( async fn get_connection_from_pool(
cache_pool: &CachePool, cache_pool: &CachePool,
) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> { ) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> {

View file

@ -38,59 +38,70 @@ pub async fn store_user_session(
Ok(()) Ok(())
} }
pub async fn get_user_session( pub async fn _get_user_session(
cache_pool: &CachePool, cache_pool: &CachePool,
token_id: Uuid, token_id: Uuid,
) -> Result<Option<Session>, AppError> { ) -> Result<Option<Session>, AppError> {
let key = make_key(token_id); let key = make_key(token_id);
let object = cache::get_object::<String>(cache_pool, key.as_str()).await?; let session_exists = cache::exists(cache_pool, key.as_str()).await?;
if let Some(object) = object { if session_exists {
let user_id: i32 = object let object = cache::_get_object::<String>(cache_pool, key.as_str()).await?;
.get("userId")
.ok_or(AppError::missing_session_field("userId"))
.and_then(|value| {
FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("userId"))
})?;
let username: String = object if let Some(object) = object {
.get("username") let user_id: i32 = object
.ok_or(AppError::missing_session_field("username")) .get("userId")
.and_then(|value| { .ok_or(AppError::_missing_session_field("userId"))
FromRedisValue::from_redis_value(value) .and_then(|value| {
.map_err(|_| AppError::missing_session_field("username")) FromRedisValue::from_redis_value(value)
})?; .map_err(|_| AppError::_missing_session_field("userId"))
})?;
let created_at: SystemTime = object let username: String = object
.get("createdAt") .get("username")
.ok_or(AppError::missing_session_field("createdAt")) .ok_or(AppError::_missing_session_field("username"))
.and_then(|value| { .and_then(|value| {
FromRedisValue::from_redis_value(value) FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("createdAt")) .map_err(|_| AppError::_missing_session_field("username"))
}) })?;
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
let expires_at: SystemTime = object let created_at: SystemTime = object
.get("createdAt") .get("createdAt")
.ok_or(AppError::missing_session_field("expiresAt")) .ok_or(AppError::_missing_session_field("createdAt"))
.and_then(|value| { .and_then(|value| {
FromRedisValue::from_redis_value(value) FromRedisValue::from_redis_value(value)
.map_err(|_| AppError::missing_session_field("expiresAt")) .map_err(|_| AppError::_missing_session_field("createdAt"))
}) })
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?; .map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
Ok(Some(Session { let expires_at: SystemTime = object
user_id, .get("createdAt")
username, .ok_or(AppError::_missing_session_field("expiresAt"))
created_at, .and_then(|value| {
expires_at, FromRedisValue::from_redis_value(value)
})) .map_err(|_| AppError::_missing_session_field("expiresAt"))
})
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
Ok(Some(Session {
user_id,
username,
created_at,
expires_at,
}))
} else {
Ok(None)
}
} else { } else {
Ok(None) Ok(None)
} }
} }
pub async fn exists(cache_pool: &CachePool, token_id: Uuid) -> Result<bool, AppError> {
let key = make_key(token_id);
cache::exists(cache_pool, key.as_str()).await
}
fn make_key(token_id: Uuid) -> String { fn make_key(token_id: Uuid) -> String {
let hashed_token = blake3::hash(token_id.as_bytes()); let hashed_token = blake3::hash(token_id.as_bytes());
format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}") format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}")