feat: alter unsafe packages

This commit is contained in:
2025-05-23 02:54:53 +08:00
parent f1d8318500
commit 0fcbc6bbe9
15 changed files with 470 additions and 405 deletions

View File

@@ -1,6 +1,8 @@
use jwt_authorizer::OneOrArray;
use std::collections::HashMap;
use jwtk::OneOrMany;
use serde::{Deserialize, Serialize};
use serde_with::{NoneAsEmptyString, serde_as};
use serde_with::serde_as;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicAuthConfig {
@@ -22,13 +24,9 @@ pub struct OidcAuthConfig {
#[serde(rename = "oidc_client_secret")]
pub client_secret: String,
#[serde(rename = "oidc_extra_scopes")]
pub extra_scopes: Option<OneOrArray<String>>,
#[serde_as(as = "NoneAsEmptyString")]
#[serde(rename = "oidc_extra_claim_key")]
pub extra_claim_key: Option<String>,
#[serde(rename = "oidc_extra_claim_value")]
#[serde_as(as = "NoneAsEmptyString")]
pub extra_claim_value: Option<String>,
pub extra_scopes: Option<OneOrMany<String>>,
#[serde(rename = "oidc_extra_claims")]
pub extra_claims: Option<HashMap<String, Option<String>>>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]

View File

@@ -27,10 +27,6 @@ pub enum AuthError {
FindAuthRecordError,
#[snafu(display("Invalid credentials"))]
BasicInvalidCredentials,
#[snafu(transparent)]
OidcInitError {
source: jwt_authorizer::error::InitError,
},
#[snafu(display("Invalid oidc provider meta client error: {source}"))]
OidcProviderHttpClientError { source: HttpClientError },
#[snafu(transparent)]
@@ -66,8 +62,10 @@ pub enum AuthError {
OidcSignatureVerificationError { source: SignatureVerificationError },
#[snafu(transparent)]
OidcSigningError { source: SigningError },
#[snafu(display("Missing Bearer token"))]
OidcMissingBearerToken,
#[snafu(transparent)]
OidcJwtAuthError { source: jwt_authorizer::AuthError },
OidcJwtkError { source: jwtk::Error },
#[snafu(display("Extra scopes {expected} do not match found scopes {found}"))]
OidcExtraScopesMatchError { expected: String, found: String },
#[snafu(display("Extra claim {key} does not match expected value {expected}, found {found}"))]

View File

@@ -12,8 +12,9 @@ use axum::{
http::{HeaderValue, request::Parts},
};
use fetch::{HttpClient, client::HttpClientError};
use http::header::AUTHORIZATION;
use itertools::Itertools;
use jwt_authorizer::{NumericDate, OneOrArray, authorizer::Authorizer};
use jwtk::jwk::RemoteJwksVerifier;
use moka::future::Cache;
use openidconnect::{
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
@@ -77,21 +78,6 @@ impl<'c> openidconnect::AsyncHttpClient<'c> for OidcHttpClient {
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OidcAuthClaims {
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<OneOrArray<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exp: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub iat: Option<NumericDate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
#[serde(flatten)]
pub custom: HashMap<String, Value>,
@@ -101,40 +87,6 @@ impl OidcAuthClaims {
pub fn scopes(&self) -> std::str::Split<'_, char> {
self.scope.as_deref().unwrap_or_default().split(',')
}
pub fn get_claim(&self, key: &str) -> Option<String> {
match key {
"iss" => self.iss.clone(),
"sub" => self.sub.clone(),
"aud" => self.aud.as_ref().map(|s| s.iter().join(",")),
"exp" => self.exp.clone().map(|s| s.0.to_string()),
"nbf" => self.nbf.clone().map(|s| s.0.to_string()),
"iat" => self.iat.clone().map(|s| s.0.to_string()),
"jti" => self.jti.clone(),
"scope" => self.scope.clone(),
key => self.custom.get(key).map(|s| s.to_string()),
}
}
pub fn has_claim(&self, key: &str) -> bool {
match key {
"iss" => self.iss.is_some(),
"sub" => self.sub.is_some(),
"aud" => self.aud.is_some(),
"exp" => self.exp.is_some(),
"nbf" => self.nbf.is_some(),
"iat" => self.iat.is_some(),
"jti" => self.jti.is_some(),
"scope" => self.scope.is_some(),
key => self.custom.contains_key(key),
}
}
pub fn contains_audience(&self, aud: &str) -> bool {
self.aud
.as_ref()
.is_some_and(|arr| arr.iter().any(|s| s == aud))
}
}
#[derive(Debug, Clone, Serialize)]
@@ -164,7 +116,7 @@ pub struct OidcAuthCallbackPayload {
pub struct OidcAuthService {
pub config: OidcAuthConfig,
pub api_authorizer: Authorizer<OidcAuthClaims>,
pub jwk_verifier: RemoteJwksVerifier,
pub oidc_provider_client: Arc<HttpClient>,
pub oidc_request_cache: Cache<String, OidcAuthRequest>,
}
@@ -317,47 +269,68 @@ impl AuthServiceTrait for OidcAuthService {
request: &mut Parts,
) -> Result<AuthUserInfo, AuthError> {
let config = &self.config;
let token = self
.api_authorizer
.extract_token(&request.headers)
.ok_or(jwt_authorizer::AuthError::MissingToken())?;
let token = request
.headers
.get(AUTHORIZATION)
.and_then(|authorization| {
authorization
.to_str()
.ok()
.and_then(|s| s.strip_prefix("Bearer "))
})
.ok_or(AuthError::OidcMissingBearerToken)?;
let token_data = self.api_authorizer.check_auth(&token).await?;
let claims = token_data.claims;
let token_data = self.jwk_verifier.verify::<OidcAuthClaims>(token).await?;
let claims = token_data.claims();
let sub = if let Some(sub) = claims.sub.as_deref() {
sub
} else {
return Err(AuthError::OidcSubMissingError);
};
if !claims.contains_audience(&config.audience) {
if !claims.aud.iter().any(|aud| aud == &config.audience) {
return Err(AuthError::OidcAudMissingError {
aud: config.audience.clone(),
});
}
let extra_claims = &claims.extra;
if let Some(expected_scopes) = config.extra_scopes.as_ref() {
let found_scopes = claims.scopes().collect::<HashSet<_>>();
let found_scopes = extra_claims.scopes().collect::<HashSet<_>>();
if !expected_scopes
.iter()
.all(|es| found_scopes.contains(es as &str))
{
return Err(AuthError::OidcExtraScopesMatchError {
expected: expected_scopes.iter().join(","),
found: claims.scope.unwrap_or_default(),
found: extra_claims
.scope
.as_deref()
.unwrap_or_default()
.to_string(),
});
}
}
if let Some(key) = config.extra_claim_key.as_ref() {
if !claims.has_claim(key) {
return Err(AuthError::OidcExtraClaimMissingError { claim: key.clone() });
}
if let Some(value) = config.extra_claim_value.as_ref()
&& claims.get_claim(key).is_none_or(|v| &v != value)
{
return Err(AuthError::OidcExtraClaimMatchError {
expected: value.clone(),
found: claims.get_claim(key).unwrap_or_default().to_string(),
key: key.clone(),
});
if let Some(expected_extra_claims) = config.extra_claims.as_ref() {
for (expected_key, expected_value) in expected_extra_claims.iter() {
match (extra_claims.custom.get(expected_key), expected_value) {
(found_value, Some(expected_value)) => {
if let Some(Value::String(found_value)) = found_value
&& expected_value == found_value
{
} else {
return Err(AuthError::OidcExtraClaimMatchError {
expected: expected_value.clone(),
found: found_value.map(|v| v.to_string()).unwrap_or_default(),
key: expected_key.clone(),
});
}
}
(None, None) => {
return Err(AuthError::OidcExtraClaimMissingError {
claim: expected_key.clone(),
});
}
_ => {}
}
}
}
let subscriber_auth = match crate::models::auth::Model::find_by_pid(ctx, sub).await {

View File

@@ -1,25 +1,22 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use axum::{
extract::FromRequestParts,
http::request::Parts,
response::{IntoResponse as _, Response},
};
use axum::http::request::Parts;
use fetch::{
HttpClient, HttpClientConfig,
client::{HttpClientCacheBackendConfig, HttpClientCachePresetConfig},
};
use http::header::HeaderValue;
use jwt_authorizer::{JwtAuthorizer, Validation};
use jwtk::jwk::RemoteJwksVerifier;
use moka::future::Cache;
use openidconnect::{IssuerUrl, core::CoreProviderMetadata};
use snafu::prelude::*;
use super::{
AuthConfig,
basic::BasicAuthService,
errors::{AuthError, OidcProviderHttpClientSnafu},
oidc::{OidcAuthClaims, OidcAuthService},
errors::{AuthError, OidcProviderHttpClientSnafu, OidcProviderUrlSnafu},
oidc::{OidcAuthService, OidcHttpClient},
};
use crate::{app::AppContextTrait, models::auth::AuthType};
@@ -29,22 +26,6 @@ pub struct AuthUserInfo {
pub auth_type: AuthType,
}
impl FromRequestParts<Arc<dyn AppContextTrait>> for AuthUserInfo {
type Rejection = Response;
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<dyn AppContextTrait>,
) -> Result<Self, Self::Rejection> {
let auth_service = state.auth();
auth_service
.extract_user_info(state.as_ref(), parts)
.await
.map_err(|err| err.into_response())
}
}
#[async_trait]
pub trait AuthServiceTrait {
async fn extract_user_info(
@@ -66,27 +47,33 @@ impl AuthService {
let result = match config {
AuthConfig::Basic(config) => AuthService::Basic(Box::new(BasicAuthService { config })),
AuthConfig::Oidc(config) => {
let validation = Validation::new()
.iss(&[&config.issuer])
.aud(&[&config.audience]);
let oidc_provider_client = Arc::new(
HttpClient::from_config(HttpClientConfig {
exponential_backoff_max_retries: Some(3),
cache_backend: Some(HttpClientCacheBackendConfig::Moka { cache_size: 1 }),
cache_preset: Some(HttpClientCachePresetConfig::RFC7234),
..Default::default()
})
.context(OidcProviderHttpClientSnafu)?,
);
let oidc_provider_client = HttpClient::from_config(HttpClientConfig {
exponential_backoff_max_retries: Some(3),
cache_backend: Some(HttpClientCacheBackendConfig::Moka { cache_size: 1 }),
cache_preset: Some(HttpClientCachePresetConfig::RFC7234),
..Default::default()
})
.context(OidcProviderHttpClientSnafu)?;
let provider_metadata = {
let client = OidcHttpClient(oidc_provider_client.clone());
let issuer_url =
IssuerUrl::new(config.issuer.clone()).context(OidcProviderUrlSnafu)?;
CoreProviderMetadata::discover_async(issuer_url, &client).await
}?;
let api_authorizer = JwtAuthorizer::<OidcAuthClaims>::from_oidc(&config.issuer)
.validation(validation)
.build()
.await?;
let jwk_verifier = RemoteJwksVerifier::new(
provider_metadata.jwks_uri().to_string().clone(),
None,
Duration::from_secs(300),
);
AuthService::Oidc(Box::new(OidcAuthService {
config,
api_authorizer,
oidc_provider_client: Arc::new(oidc_provider_client),
jwk_verifier,
oidc_provider_client,
oidc_request_cache: Cache::builder()
.time_to_live(Duration::from_mins(5))
.name("oidc_request_cache")
@@ -100,6 +87,7 @@ impl AuthService {
#[async_trait]
impl AuthServiceTrait for AuthService {
#[tracing::instrument(skip(self, ctx, request))]
async fn extract_user_info(
&self,
ctx: &dyn AppContextTrait,