fix: add basic auth and oidc auth

This commit is contained in:
2024-12-31 00:52:44 +08:00
parent 4c6cc1116b
commit abd399aacd
39 changed files with 712 additions and 49 deletions

View File

@@ -0,0 +1,31 @@
use axum::{http::request::Parts, RequestPartsExt};
use axum_auth::AuthBasic;
use super::{
config::BasicAuthConfig,
errors::AuthError,
service::{AuthService, AuthUserInfo},
};
use crate::models::{auth::AuthType, subscribers::SEED_SUBSCRIBER};
#[derive(Debug)]
pub struct BasicAuthService {
pub config: BasicAuthConfig,
}
#[async_trait::async_trait]
impl AuthService for BasicAuthService {
async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
if let Ok(AuthBasic((found_user, found_password))) = request.extract().await {
if self.config.user == found_user
&& self.config.password == found_password.unwrap_or_default()
{
return Ok(AuthUserInfo {
user_pid: SEED_SUBSCRIBER.to_string(),
auth_type: AuthType::Basic,
});
}
}
Err(AuthError::BasicInvalidCredentials)
}
}

View File

@@ -0,0 +1,31 @@
use jwt_authorizer::OneOrArray;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicAuthConfig {
#[serde(rename = "basic_user")]
pub user: String,
#[serde(rename = "basic_password")]
pub password: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OidcAuthConfig {
#[serde(rename = "oidc_api_issuer")]
pub issuer: String,
#[serde(rename = "oidc_api_audience")]
pub audience: String,
#[serde(rename = "oidc_extra_scopes")]
pub extra_scopes: Option<OneOrArray<String>>,
#[serde(rename = "oidc_extra_claim_key")]
pub extra_claim_key: Option<String>,
#[serde(rename = "oidc_extra_claim_value")]
pub extra_claim_value: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AppAuthConfig {
Basic(BasicAuthConfig),
Oidc(OidcAuthConfig),
}

View File

@@ -0,0 +1,35 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AuthError {
#[error(transparent)]
OidcInitError(#[from] jwt_authorizer::error::InitError),
#[error("Invalid credentials")]
BasicInvalidCredentials,
#[error(transparent)]
OidcJwtAuthError(#[from] jwt_authorizer::AuthError),
#[error("Extra scopes {expected} do not match found scopes {found}")]
OidcExtraScopesMatchError { expected: String, found: String },
#[error("Extra claim {key} does not match expected value {expected}, found {found}")]
OidcExtraClaimMatchError {
key: String,
expected: String,
found: String,
},
#[error("Extra claim {0} missing")]
OidcExtraClaimMissingError(String),
#[error("Audience {0} missing")]
OidcAudMissingError(String),
#[error("Subject missing")]
OidcSubMissingError,
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
}

View File

@@ -0,0 +1,9 @@
pub mod basic;
pub mod config;
pub mod errors;
pub mod oidc;
pub mod service;
pub use config::{AppAuthConfig, BasicAuthConfig, OidcAuthConfig};
pub use errors::AuthError;
pub use service::{AppAuthService, AuthService, AuthUserInfo};

View File

@@ -0,0 +1,137 @@
use std::collections::{HashMap, HashSet};
use axum::http::request::Parts;
use itertools::Itertools;
use jwt_authorizer::{authorizer::Authorizer, NumericDate, OneOrArray};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::{
config::OidcAuthConfig,
errors::AuthError,
service::{AuthService, AuthUserInfo},
};
use crate::models::auth::AuthType;
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OidcAuthClaims {
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<OneOrArray<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exp: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub iat: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
#[serde(flatten)]
pub custom: HashMap<String, Value>,
}
impl OidcAuthClaims {
pub fn scopes(&self) -> std::str::Split<'_, char> {
self.scope.as_deref().unwrap_or_default().split(',')
}
pub fn get_claim(&self, key: &str) -> Option<String> {
match key {
"iss" => self.iss.clone(),
"sub" => self.sub.clone(),
"aud" => self.aud.as_ref().map(|s| s.iter().join(",")),
"exp" => self.exp.clone().map(|s| s.0.to_string()),
"nbf" => self.nbf.clone().map(|s| s.0.to_string()),
"iat" => self.iat.clone().map(|s| s.0.to_string()),
"jti" => self.jti.clone(),
"scope" => self.scope.clone(),
key => self.custom.get(key).map(|s| s.to_string()),
}
}
pub fn has_claim(&self, key: &str) -> bool {
match key {
"iss" => self.iss.is_some(),
"sub" => self.sub.is_some(),
"aud" => self.aud.is_some(),
"exp" => self.exp.is_some(),
"nbf" => self.nbf.is_some(),
"iat" => self.iat.is_some(),
"jti" => self.jti.is_some(),
"scope" => self.scope.is_some(),
key => self.custom.contains_key(key),
}
}
pub fn contains_audience(&self, aud: &str) -> bool {
self.aud
.as_ref()
.is_some_and(|arr| arr.iter().any(|s| s == aud))
}
}
pub struct OidcAuthService {
pub config: OidcAuthConfig,
pub authorizer: Authorizer<OidcAuthClaims>,
}
#[async_trait::async_trait]
impl AuthService for OidcAuthService {
async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
let config = &self.config;
let token =
self.authorizer
.extract_token(&request.headers)
.ok_or(AuthError::OidcJwtAuthError(
jwt_authorizer::AuthError::MissingToken(),
))?;
let token_data = self.authorizer.check_auth(&token).await?;
let claims = token_data.claims;
if !claims.sub.as_deref().is_some_and(|s| !s.trim().is_empty()) {
return Err(AuthError::OidcSubMissingError);
}
if !claims.contains_audience(&config.audience) {
return Err(AuthError::OidcAudMissingError(config.audience.clone()));
}
if let Some(expected_scopes) = config.extra_scopes.as_ref() {
let found_scopes = claims.scopes().collect::<HashSet<_>>();
if !expected_scopes
.iter()
.all(|es| found_scopes.contains(&es as &str))
{
return Err(AuthError::OidcExtraScopesMatchError {
expected: expected_scopes.iter().join(","),
found: claims.scope.unwrap_or_default(),
});
}
}
if let Some(key) = config.extra_claim_key.as_ref() {
if !claims.has_claim(key) {
return Err(AuthError::OidcExtraClaimMissingError(key.clone()));
}
if let Some(value) = config.extra_claim_value.as_ref() {
if claims.get_claim(key).is_none_or(|v| &v != value) {
return Err(AuthError::OidcExtraClaimMatchError {
expected: value.clone(),
found: claims.get_claim(key).unwrap_or_default().to_string(),
key: key.clone(),
});
}
}
}
Ok(AuthUserInfo {
user_pid: claims
.sub
.as_deref()
.map(|s| s.trim().to_string())
.unwrap_or_else(|| unreachable!("sub should be present and validated")),
auth_type: AuthType::Oidc,
})
}
}

View File

@@ -0,0 +1,116 @@
use axum::{
extract::FromRequestParts,
http::request::Parts,
response::{IntoResponse as _, Response},
Extension,
};
use jwt_authorizer::{JwtAuthorizer, Validation};
use loco_rs::app::{AppContext, Initializer};
use once_cell::sync::OnceCell;
use super::{
basic::BasicAuthService,
errors::AuthError,
oidc::{OidcAuthClaims, OidcAuthService},
AppAuthConfig,
};
use crate::{app::AppContextExt as _, config::AppConfigExt, models::auth::AuthType};
pub struct AuthUserInfo {
pub user_pid: String,
pub auth_type: AuthType,
}
#[async_trait::async_trait]
impl<S> FromRequestParts<S> for AuthUserInfo
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(ctx) = Extension::<AppContext>::from_request_parts(req, state)
.await
.expect("AppContext should be present");
let auth_service = ctx.get_auth_service();
auth_service
.extract_user_info(req)
.await
.map_err(|err| err.into_response())
}
}
#[async_trait::async_trait]
pub trait AuthService {
async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError>;
}
pub enum AppAuthService {
Basic(BasicAuthService),
Oidc(OidcAuthService),
}
static APP_AUTH_SERVICE: OnceCell<AppAuthService> = OnceCell::new();
impl AppAuthService {
pub fn app_instance() -> &'static Self {
APP_AUTH_SERVICE
.get()
.expect("AppAuthService is not initialized")
}
pub async fn from_conf(config: AppAuthConfig) -> Result<Self, AuthError> {
let result = match config {
AppAuthConfig::Basic(config) => AppAuthService::Basic(BasicAuthService { config }),
AppAuthConfig::Oidc(config) => {
let validation = Validation::new()
.iss(&[&config.issuer])
.aud(&[&config.audience]);
let jwt_auth = JwtAuthorizer::<OidcAuthClaims>::from_oidc(&config.issuer)
.validation(validation)
.build()
.await?;
AppAuthService::Oidc(OidcAuthService {
config,
authorizer: jwt_auth,
})
}
};
Ok(result)
}
}
#[async_trait::async_trait]
impl AuthService for AppAuthService {
async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
match self {
AppAuthService::Basic(service) => service.extract_user_info(request).await,
AppAuthService::Oidc(service) => service.extract_user_info(request).await,
}
}
}
pub struct AppAuthServiceInitializer;
#[async_trait::async_trait]
impl Initializer for AppAuthServiceInitializer {
fn name(&self) -> String {
String::from("AppAuthServiceInitializer")
}
async fn before_run(&self, ctx: &AppContext) -> Result<(), loco_rs::Error> {
let auth_conf = ctx.config.get_app_conf()?.auth;
let service = AppAuthService::from_conf(auth_conf)
.await
.map_err(|e| loco_rs::Error::wrap(e))?;
APP_AUTH_SERVICE.get_or_init(|| service);
Ok(())
}
}