Fetch the temporary session during validation
This commit is contained in:
parent
d41a92d07a
commit
24a7fdd4ef
6 changed files with 128 additions and 19 deletions
|
@ -5,7 +5,7 @@ use bb8_redis::bb8::RunError;
|
|||
use http::StatusCode;
|
||||
use redis::RedisError;
|
||||
use sqlx::{error::DatabaseError, migrate::MigrateError, Error as SqlxError};
|
||||
use tracing::trace;
|
||||
use tracing::{field, trace};
|
||||
|
||||
use super::ApiResponse;
|
||||
|
||||
|
@ -45,6 +45,10 @@ impl AppError {
|
|||
Self::new(ErrorKind::MissingEnvironmentVariables(missing_vars))
|
||||
}
|
||||
|
||||
pub fn missing_session_field(field: &'static str) -> Self {
|
||||
Self::new(ErrorKind::MissingSessionField(field))
|
||||
}
|
||||
|
||||
pub fn token_key() -> Self {
|
||||
Self::new(ErrorKind::TokenKey)
|
||||
}
|
||||
|
@ -120,6 +124,10 @@ impl Display for AppError {
|
|||
"Missing required environment variables: {}",
|
||||
missing_vars.join(", ")
|
||||
),
|
||||
ErrorKind::MissingSessionField(field) => write!(
|
||||
f,
|
||||
"Cannot retrieve session: missing required field: {field}"
|
||||
),
|
||||
ErrorKind::Sqlx(err) => write!(f, "{err}"),
|
||||
ErrorKind::TokenKey => write!(
|
||||
f,
|
||||
|
@ -140,6 +148,7 @@ enum ErrorKind {
|
|||
InvalidPassword,
|
||||
InvalidToken,
|
||||
MissingEnvironmentVariables(Vec<&'static str>),
|
||||
MissingSessionField(&'static str),
|
||||
Sqlx(SqlxError),
|
||||
TokenKey,
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::time::SystemTime;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use humantime::format_rfc3339;
|
||||
use redis::ToRedisArgs;
|
||||
|
@ -30,3 +30,14 @@ impl Session {
|
|||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Session {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
user_id: Default::default(),
|
||||
username: Default::default(),
|
||||
created_at: UNIX_EPOCH,
|
||||
expires_at: UNIX_EPOCH,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use axum::{
|
||||
debug_handler,
|
||||
extract::{Path, Query, State},
|
||||
|
@ -6,13 +8,14 @@ use axum::{
|
|||
};
|
||||
use http::StatusCode;
|
||||
use pasetors::{claims::ClaimsValidationRules, keys::SymmetricKey, version4::V4};
|
||||
use tracing::{debug, error};
|
||||
use tracing::{debug, error, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
db::{verify_user, DbPool},
|
||||
models::{ApiResponse, AppError},
|
||||
requests::AppState,
|
||||
services::auth_token::verify_token,
|
||||
services::{auth_token::verify_token, user_session, CachePool},
|
||||
};
|
||||
|
||||
use super::{UserVerifyGetParams, UserVerifyGetResponse};
|
||||
|
@ -23,16 +26,18 @@ pub async fn user_verification_get_handler(
|
|||
Path(user_id): Path<i32>,
|
||||
Query(query): Query<UserVerifyGetParams>,
|
||||
) -> Result<Response, AppError> {
|
||||
let pool = state.db_pool();
|
||||
let db_pool = state.db_pool();
|
||||
let cache_pool = state.cache_pool();
|
||||
let env = state.env();
|
||||
|
||||
let UserVerifyGetParams { verification_token } = query;
|
||||
let token_key = env.token_key();
|
||||
verify_new_user_request(pool, user_id, verification_token, token_key).await
|
||||
verify_new_user_request(db_pool, cache_pool, user_id, verification_token, token_key).await
|
||||
}
|
||||
|
||||
async fn verify_new_user_request(
|
||||
pool: &DbPool,
|
||||
db_pool: &DbPool,
|
||||
cache_pool: &CachePool,
|
||||
user_id: i32,
|
||||
verification_token: String,
|
||||
token_key: &SymmetricKey<V4>,
|
||||
|
@ -45,15 +50,30 @@ async fn verify_new_user_request(
|
|||
rules
|
||||
};
|
||||
|
||||
let verified_token = verify_token(
|
||||
token_key,
|
||||
verification_token.as_str(),
|
||||
Some(validation_rules.clone()),
|
||||
)
|
||||
.inspect_err(|err| error!(?err))?;
|
||||
|
||||
let token_id = verified_token
|
||||
.payload_claims()
|
||||
.map(|claims| claims.get_claim("jti"))
|
||||
.flatten()
|
||||
.map(|jti| Uuid::from_str(jti.as_str().unwrap()).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let _ = user_session::get_user_session(cache_pool, token_id).await?;
|
||||
|
||||
let response = verify_token(
|
||||
token_key,
|
||||
verification_token.as_str(),
|
||||
Some(validation_rules),
|
||||
)
|
||||
.map(|_| UserVerifyGetResponse::new(token_key, user_id))
|
||||
.inspect_err(|err| error!(?err))?;
|
||||
.map(|_| UserVerifyGetResponse::new(token_key, user_id))?;
|
||||
|
||||
verify_user(pool, user_id)
|
||||
verify_user(db_pool, user_id)
|
||||
.await
|
||||
.inspect_err(|err| error!(?err))?;
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ use pasetors::{
|
|||
footer::Footer,
|
||||
keys::SymmetricKey,
|
||||
local,
|
||||
token::UntrustedToken,
|
||||
token::{TrustedToken, UntrustedToken},
|
||||
version4::V4,
|
||||
};
|
||||
use tracing::error;
|
||||
|
@ -21,7 +21,7 @@ pub fn verify_token(
|
|||
key: &SymmetricKey<V4>,
|
||||
token: &str,
|
||||
validation_rules: Option<ClaimsValidationRules>,
|
||||
) -> Result<(), AppError> {
|
||||
) -> Result<TrustedToken, AppError> {
|
||||
let token = UntrustedToken::try_from(token)
|
||||
.inspect_err(|err| error!(?err))
|
||||
.map_err(|_| AppError::invalid_token())?;
|
||||
|
@ -38,7 +38,7 @@ pub fn verify_token(
|
|||
footer
|
||||
};
|
||||
|
||||
let _ = local::decrypt(
|
||||
let token = local::decrypt(
|
||||
key,
|
||||
&token,
|
||||
&validation_rules,
|
||||
|
@ -48,7 +48,7 @@ pub fn verify_token(
|
|||
.inspect_err(|err| error!(?err))
|
||||
.map_err(|_| AppError::invalid_token())?;
|
||||
|
||||
Ok(())
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn generate_access_token(key: &SymmetricKey<V4>, user_id: i32) -> (String, Uuid, SystemTime) {
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use std::{fmt::Debug, time::Duration};
|
||||
use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration};
|
||||
|
||||
use bb8_redis::{
|
||||
bb8::{Pool, PooledConnection},
|
||||
RedisConnectionManager,
|
||||
};
|
||||
use redis::{AsyncCommands, IntoConnectionInfo, ToRedisArgs};
|
||||
use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo, ToRedisArgs, Value};
|
||||
|
||||
use crate::models::AppError;
|
||||
|
||||
|
@ -57,6 +57,16 @@ pub async fn store_object_with_expiration<
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_object<K: FromRedisValue + Eq + Hash>(
|
||||
cache_pool: &CachePool,
|
||||
key: &str,
|
||||
) -> Result<Option<HashMap<K, Value>>, AppError> {
|
||||
let mut conn = get_connection_from_pool(cache_pool).await?;
|
||||
let hash_fields: Option<HashMap<_, Value>> = conn.hgetall(key).await?;
|
||||
|
||||
Ok(hash_fields)
|
||||
}
|
||||
|
||||
async fn get_connection_from_pool(
|
||||
cache_pool: &CachePool,
|
||||
) -> Result<PooledConnection<'_, RedisConnectionManager>, AppError> {
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use std::time::Duration;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use humantime::parse_rfc3339;
|
||||
use redis::FromRedisValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
|
@ -17,8 +19,7 @@ pub async fn store_user_session(
|
|||
session: Session,
|
||||
expiration: Option<Duration>,
|
||||
) -> Result<(), AppError> {
|
||||
let hashed_token = blake3::hash(token_id.as_bytes());
|
||||
let key = format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}");
|
||||
let key = make_key(token_id);
|
||||
let cacheable_object = session.to_cacheable_object();
|
||||
|
||||
match expiration {
|
||||
|
@ -36,3 +37,61 @@ pub async fn store_user_session(
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_user_session(
|
||||
cache_pool: &CachePool,
|
||||
token_id: Uuid,
|
||||
) -> Result<Option<Session>, AppError> {
|
||||
let key = make_key(token_id);
|
||||
let object = cache::get_object::<String>(cache_pool, key.as_str()).await?;
|
||||
|
||||
if let Some(object) = object {
|
||||
let user_id: i32 = object
|
||||
.get("userId")
|
||||
.ok_or(AppError::missing_session_field("userId"))
|
||||
.and_then(|value| {
|
||||
FromRedisValue::from_redis_value(value)
|
||||
.map_err(|_| AppError::missing_session_field("userId"))
|
||||
})?;
|
||||
|
||||
let username: String = object
|
||||
.get("username")
|
||||
.ok_or(AppError::missing_session_field("username"))
|
||||
.and_then(|value| {
|
||||
FromRedisValue::from_redis_value(value)
|
||||
.map_err(|_| AppError::missing_session_field("username"))
|
||||
})?;
|
||||
|
||||
let created_at: SystemTime = object
|
||||
.get("createdAt")
|
||||
.ok_or(AppError::missing_session_field("createdAt"))
|
||||
.and_then(|value| {
|
||||
FromRedisValue::from_redis_value(value)
|
||||
.map_err(|_| AppError::missing_session_field("createdAt"))
|
||||
})
|
||||
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
|
||||
|
||||
let expires_at: SystemTime = object
|
||||
.get("createdAt")
|
||||
.ok_or(AppError::missing_session_field("expiresAt"))
|
||||
.and_then(|value| {
|
||||
FromRedisValue::from_redis_value(value)
|
||||
.map_err(|_| AppError::missing_session_field("expiresAt"))
|
||||
})
|
||||
.map(|serialized: String| parse_rfc3339(serialized.as_str()).unwrap())?;
|
||||
|
||||
Ok(Some(Session {
|
||||
user_id,
|
||||
username,
|
||||
created_at,
|
||||
expires_at,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_key(token_id: Uuid) -> String {
|
||||
let hashed_token = blake3::hash(token_id.as_bytes());
|
||||
format!("{USER_SESSION_CACHE_KEY_PREFIX}{hashed_token}")
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue