refactor: switch error handle to snafu
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
use std::fmt;
|
||||
|
||||
use async_graphql::dynamic::ResolverContext;
|
||||
use axum::{
|
||||
Json,
|
||||
@@ -11,72 +9,86 @@ use openidconnect::{
|
||||
StandardErrorResponse, core::CoreErrorResponseType,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::{fetch::HttpClientError, models::auth::AuthType};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum AuthError {
|
||||
#[error("Not support auth method")]
|
||||
#[snafu(display("Not support auth method"))]
|
||||
NotSupportAuthMethod {
|
||||
supported: Vec<AuthType>,
|
||||
current: AuthType,
|
||||
},
|
||||
#[error("Failed to find auth record")]
|
||||
#[snafu(display("Failed to find auth record"))]
|
||||
FindAuthRecordError,
|
||||
#[error("Invalid credentials")]
|
||||
#[snafu(display("Invalid credentials"))]
|
||||
BasicInvalidCredentials,
|
||||
#[error(transparent)]
|
||||
OidcInitError(#[from] jwt_authorizer::error::InitError),
|
||||
#[error("Invalid oidc provider meta client error: {0}")]
|
||||
OidcProviderHttpClientError(HttpClientError),
|
||||
#[error(transparent)]
|
||||
OidcProviderMetaError(#[from] openidconnect::DiscoveryError<HttpClientError>),
|
||||
#[error("Invalid oidc provider URL: {0}")]
|
||||
OidcProviderUrlError(url::ParseError),
|
||||
#[error("Invalid oidc redirect URI: {0}")]
|
||||
OidcRequestRedirectUriError(url::ParseError),
|
||||
#[error("Oidc request session not found or expired")]
|
||||
#[snafu(transparent)]
|
||||
OidcInitError {
|
||||
source: jwt_authorizer::error::InitError,
|
||||
},
|
||||
#[snafu(display("Invalid oidc provider meta client error: {source}"))]
|
||||
OidcProviderHttpClientError { source: HttpClientError },
|
||||
#[snafu(transparent)]
|
||||
OidcProviderMetaError {
|
||||
source: openidconnect::DiscoveryError<HttpClientError>,
|
||||
},
|
||||
#[snafu(display("Invalid oidc provider URL: {source}"))]
|
||||
OidcProviderUrlError { source: url::ParseError },
|
||||
#[snafu(display("Invalid oidc redirect URI: {source}"))]
|
||||
OidcRequestRedirectUriError {
|
||||
#[snafu(source)]
|
||||
source: url::ParseError,
|
||||
},
|
||||
#[snafu(display("Oidc request session not found or expired"))]
|
||||
OidcCallbackRecordNotFoundOrExpiredError,
|
||||
#[error("Invalid oidc request callback nonce")]
|
||||
#[snafu(display("Invalid oidc request callback nonce"))]
|
||||
OidcInvalidNonceError,
|
||||
#[error("Invalid oidc request callback state")]
|
||||
#[snafu(display("Invalid oidc request callback state"))]
|
||||
OidcInvalidStateError,
|
||||
#[error("Invalid oidc request callback code")]
|
||||
#[snafu(display("Invalid oidc request callback code"))]
|
||||
OidcInvalidCodeError,
|
||||
#[error(transparent)]
|
||||
OidcCallbackTokenConfigurationError(#[from] ConfigurationError),
|
||||
#[error(transparent)]
|
||||
OidcRequestTokenError(
|
||||
#[from] RequestTokenError<HttpClientError, StandardErrorResponse<CoreErrorResponseType>>,
|
||||
),
|
||||
#[error("Invalid oidc id token")]
|
||||
#[snafu(transparent)]
|
||||
OidcCallbackTokenConfigurationError { source: ConfigurationError },
|
||||
#[snafu(transparent)]
|
||||
OidcRequestTokenError {
|
||||
source: RequestTokenError<HttpClientError, StandardErrorResponse<CoreErrorResponseType>>,
|
||||
},
|
||||
#[snafu(display("Invalid oidc id token"))]
|
||||
OidcInvalidIdTokenError,
|
||||
#[error("Invalid oidc access token")]
|
||||
#[snafu(display("Invalid oidc access token"))]
|
||||
OidcInvalidAccessTokenError,
|
||||
#[error(transparent)]
|
||||
OidcSignatureVerificationError(#[from] SignatureVerificationError),
|
||||
#[error(transparent)]
|
||||
OidcSigningError(#[from] SigningError),
|
||||
#[error(transparent)]
|
||||
OidcJwtAuthError(#[from] jwt_authorizer::AuthError),
|
||||
#[error("Extra scopes {expected} do not match found scopes {found}")]
|
||||
#[snafu(transparent)]
|
||||
OidcSignatureVerificationError { source: SignatureVerificationError },
|
||||
#[snafu(transparent)]
|
||||
OidcSigningError { source: SigningError },
|
||||
#[snafu(transparent)]
|
||||
OidcJwtAuthError { source: jwt_authorizer::AuthError },
|
||||
#[snafu(display("Extra scopes {expected} do not match found scopes {found}"))]
|
||||
OidcExtraScopesMatchError { expected: String, found: String },
|
||||
#[error("Extra claim {key} does not match expected value {expected}, found {found}")]
|
||||
#[snafu(display("Extra claim {key} does not match expected value {expected}, found {found}"))]
|
||||
OidcExtraClaimMatchError {
|
||||
key: String,
|
||||
expected: String,
|
||||
found: String,
|
||||
},
|
||||
#[error("Extra claim {0} missing")]
|
||||
OidcExtraClaimMissingError(String),
|
||||
#[error("Audience {0} missing")]
|
||||
OidcAudMissingError(String),
|
||||
#[error("Subject missing")]
|
||||
#[snafu(display("Extra claim {claim} missing"))]
|
||||
OidcExtraClaimMissingError { claim: String },
|
||||
#[snafu(display("Audience {aud} missing"))]
|
||||
OidcAudMissingError { aud: String },
|
||||
#[snafu(display("Subject missing"))]
|
||||
OidcSubMissingError,
|
||||
#[error(fmt = display_graphql_permission_error)]
|
||||
#[snafu(display(
|
||||
"GraphQL permission denied since {context_path}{}{field}{}{column}: {}",
|
||||
(if field.is_empty() { "" } else { "." }),
|
||||
(if column.is_empty() { "" } else { "." }),
|
||||
source.message
|
||||
))]
|
||||
GraphQLPermissionError {
|
||||
inner_error: async_graphql::Error,
|
||||
#[snafu(source(false))]
|
||||
source: Box<async_graphql::Error>,
|
||||
field: String,
|
||||
column: String,
|
||||
context_path: String,
|
||||
@@ -85,13 +97,13 @@ pub enum AuthError {
|
||||
|
||||
impl AuthError {
|
||||
pub fn from_graphql_subscribe_id_guard(
|
||||
inner_error: async_graphql::Error,
|
||||
source: async_graphql::Error,
|
||||
context: &ResolverContext,
|
||||
field_name: &str,
|
||||
column_name: &str,
|
||||
) -> AuthError {
|
||||
AuthError::GraphQLPermissionError {
|
||||
inner_error,
|
||||
source: Box::new(source),
|
||||
field: field_name.to_string(),
|
||||
column: column_name.to_string(),
|
||||
context_path: context
|
||||
@@ -103,22 +115,6 @@ impl AuthError {
|
||||
}
|
||||
}
|
||||
|
||||
fn display_graphql_permission_error(
|
||||
inner_error: &async_graphql::Error,
|
||||
field: &String,
|
||||
column: &String,
|
||||
context_path: &String,
|
||||
formatter: &mut fmt::Formatter<'_>,
|
||||
) -> fmt::Result {
|
||||
write!(
|
||||
formatter,
|
||||
"GraphQL permission denied since {context_path}{}{field}{}{column}: {}",
|
||||
(if field.is_empty() { "" } else { "." }),
|
||||
(if column.is_empty() { "" } else { "." }),
|
||||
inner_error.message
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthErrorResponse {
|
||||
pub success: bool,
|
||||
|
||||
@@ -16,11 +16,12 @@ use openidconnect::{
|
||||
use sea_orm::DbErr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use snafu::ResultExt;
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
config::OidcAuthConfig,
|
||||
errors::AuthError,
|
||||
errors::{AuthError, OidcProviderUrlSnafu, OidcRequestRedirectUriSnafu},
|
||||
service::{AuthServiceTrait, AuthUserInfo},
|
||||
};
|
||||
use crate::{app::AppContextTrait, errors::RError, fetch::HttpClient, models::auth::AuthType};
|
||||
@@ -125,13 +126,13 @@ impl OidcAuthService {
|
||||
redirect_uri: &str,
|
||||
) -> Result<OidcAuthRequest, AuthError> {
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(self.config.issuer.clone()).map_err(AuthError::OidcProviderUrlError)?,
|
||||
IssuerUrl::new(self.config.issuer.clone()).context(OidcProviderUrlSnafu)?,
|
||||
&self.oidc_provider_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let redirect_uri = RedirectUrl::new(redirect_uri.to_string())
|
||||
.map_err(AuthError::OidcRequestRedirectUriError)?;
|
||||
let redirect_uri =
|
||||
RedirectUrl::new(redirect_uri.to_string()).context(OidcRequestRedirectUriSnafu)?;
|
||||
|
||||
let oidc_client = CoreClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
@@ -207,7 +208,7 @@ impl OidcAuthService {
|
||||
let request_cache = self.load_authorization_request(&csrf_token).await?;
|
||||
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(self.config.issuer.clone()).map_err(AuthError::OidcProviderUrlError)?,
|
||||
IssuerUrl::new(self.config.issuer.clone()).context(OidcProviderUrlSnafu)?,
|
||||
&self.oidc_provider_client,
|
||||
)
|
||||
.await?;
|
||||
@@ -265,9 +266,10 @@ impl AuthServiceTrait for OidcAuthService {
|
||||
request: &mut Parts,
|
||||
) -> Result<AuthUserInfo, AuthError> {
|
||||
let config = &self.config;
|
||||
let token = self.api_authorizer.extract_token(&request.headers).ok_or(
|
||||
AuthError::OidcJwtAuthError(jwt_authorizer::AuthError::MissingToken()),
|
||||
)?;
|
||||
let token = self
|
||||
.api_authorizer
|
||||
.extract_token(&request.headers)
|
||||
.ok_or(jwt_authorizer::AuthError::MissingToken())?;
|
||||
|
||||
let token_data = self.api_authorizer.check_auth(&token).await?;
|
||||
let claims = token_data.claims;
|
||||
@@ -277,7 +279,9 @@ impl AuthServiceTrait for OidcAuthService {
|
||||
return Err(AuthError::OidcSubMissingError);
|
||||
};
|
||||
if !claims.contains_audience(&config.audience) {
|
||||
return Err(AuthError::OidcAudMissingError(config.audience.clone()));
|
||||
return Err(AuthError::OidcAudMissingError {
|
||||
aud: config.audience.clone(),
|
||||
});
|
||||
}
|
||||
if let Some(expected_scopes) = config.extra_scopes.as_ref() {
|
||||
let found_scopes = claims.scopes().collect::<HashSet<_>>();
|
||||
@@ -293,7 +297,7 @@ impl AuthServiceTrait for OidcAuthService {
|
||||
}
|
||||
if let Some(key) = config.extra_claim_key.as_ref() {
|
||||
if !claims.has_claim(key) {
|
||||
return Err(AuthError::OidcExtraClaimMissingError(key.clone()));
|
||||
return Err(AuthError::OidcExtraClaimMissingError { claim: key.clone() });
|
||||
}
|
||||
if let Some(value) = config.extra_claim_value.as_ref() {
|
||||
if claims.get_claim(key).is_none_or(|v| &v != value) {
|
||||
@@ -306,9 +310,9 @@ impl AuthServiceTrait for OidcAuthService {
|
||||
}
|
||||
}
|
||||
let subscriber_auth = match crate::models::auth::Model::find_by_pid(ctx, sub).await {
|
||||
Err(RError::DbError(DbErr::RecordNotFound(..))) => {
|
||||
crate::models::auth::Model::create_from_oidc(ctx, sub.to_string()).await
|
||||
}
|
||||
Err(RError::DbError {
|
||||
source: DbErr::RecordNotFound(..),
|
||||
}) => crate::models::auth::Model::create_from_oidc(ctx, sub.to_string()).await,
|
||||
r => r,
|
||||
}
|
||||
.map_err(|_| AuthError::FindAuthRecordError)?;
|
||||
|
||||
@@ -9,11 +9,12 @@ use axum::{
|
||||
use jwt_authorizer::{JwtAuthorizer, Validation};
|
||||
use moka::future::Cache;
|
||||
use reqwest::header::HeaderValue;
|
||||
use snafu::prelude::*;
|
||||
|
||||
use super::{
|
||||
AuthConfig,
|
||||
basic::BasicAuthService,
|
||||
errors::AuthError,
|
||||
errors::{AuthError, OidcProviderHttpClientSnafu},
|
||||
oidc::{OidcAuthClaims, OidcAuthService},
|
||||
};
|
||||
use crate::{
|
||||
@@ -59,14 +60,14 @@ pub trait AuthServiceTrait {
|
||||
}
|
||||
|
||||
pub enum AuthService {
|
||||
Basic(BasicAuthService),
|
||||
Oidc(OidcAuthService),
|
||||
Basic(Box<BasicAuthService>),
|
||||
Oidc(Box<OidcAuthService>),
|
||||
}
|
||||
|
||||
impl AuthService {
|
||||
pub async fn from_conf(config: AuthConfig) -> Result<Self, AuthError> {
|
||||
let result = match config {
|
||||
AuthConfig::Basic(config) => AuthService::Basic(BasicAuthService { config }),
|
||||
AuthConfig::Basic(config) => AuthService::Basic(Box::new(BasicAuthService { config })),
|
||||
AuthConfig::Oidc(config) => {
|
||||
let validation = Validation::new()
|
||||
.iss(&[&config.issuer])
|
||||
@@ -78,14 +79,14 @@ impl AuthService {
|
||||
cache_preset: Some(HttpClientCachePresetConfig::RFC7234),
|
||||
..Default::default()
|
||||
})
|
||||
.map_err(AuthError::OidcProviderHttpClientError)?;
|
||||
.context(OidcProviderHttpClientSnafu)?;
|
||||
|
||||
let api_authorizer = JwtAuthorizer::<OidcAuthClaims>::from_oidc(&config.issuer)
|
||||
.validation(validation)
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
AuthService::Oidc(OidcAuthService {
|
||||
AuthService::Oidc(Box::new(OidcAuthService {
|
||||
config,
|
||||
api_authorizer,
|
||||
oidc_provider_client,
|
||||
@@ -93,7 +94,7 @@ impl AuthService {
|
||||
.time_to_live(Duration::from_mins(5))
|
||||
.name("oidc_request_cache")
|
||||
.build(),
|
||||
})
|
||||
}))
|
||||
}
|
||||
};
|
||||
Ok(result)
|
||||
|
||||
Reference in New Issue
Block a user