refactor: remove loco-rs deps

This commit is contained in:
2025-03-01 15:21:14 +08:00
parent a68aab1452
commit 2844e1fc32
66 changed files with 2565 additions and 1876 deletions

View File

@@ -0,0 +1,30 @@
use serde::{Deserialize, Serialize};
use super::middleware::MiddlewareConfig;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WebServerConfig {
/// The address on which the server should listen on for incoming
/// connections.
#[serde(default = "default_binding")]
pub binding: String,
/// The port on which the server should listen for incoming connections.
#[serde(default = "default_port")]
pub port: u16,
/// The webserver host
pub host: String,
/// Identify via the `Server` header
pub ident: Option<String>,
/// Middleware configurations for the server, including payload limits,
/// logging, and error handling.
#[serde(default)]
pub middlewares: MiddlewareConfig,
}
pub fn default_binding() -> String {
"127.0.0.1".to_string()
}
pub fn default_port() -> u16 {
5_001
}

View File

@@ -0,0 +1,52 @@
import type { Injector } from '@outposts/injection-js';
import {
// Link,
Outlet,
createRootRouteWithContext,
} from '@tanstack/react-router';
import { TanStackRouterDevtools } from '@tanstack/router-devtools';
import type { OidcSecurityService } from 'oidc-client-rx';
export type RouterContext =
| {
isAuthenticated: false;
injector: Injector;
oidcSecurityService: OidcSecurityService;
}
| {
isAuthenticated: true;
injector?: Injector;
oidcSecurityService?: OidcSecurityService;
};
export const Route = createRootRouteWithContext<RouterContext>()({
component: RootComponent,
});
function RootComponent() {
return (
<>
{/* <div className="flex gap-2 p-2 text-lg ">
<Link
to="/"
activeProps={{
className: 'font-bold',
}}
>
Home
</Link>{' '}
<Link
to="/graphql"
activeProps={{
className: 'font-bold',
}}
>
GraphQL
</Link>
</div> */}
{/* <hr /> */}
<Outlet />
<TanStackRouterDevtools position="bottom-right" />
</>
);
}

View File

@@ -0,0 +1,50 @@
use std::{borrow::Cow, sync::Arc};
use axum::Router;
use crate::app::AppContext;
pub trait ControllerTrait: Sized {
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>>;
}
pub struct PrefixController {
prefix: Cow<'static, str>,
router: Router<Arc<AppContext>>,
}
impl PrefixController {
pub fn new(prefix: impl Into<Cow<'static, str>>, router: Router<Arc<AppContext>>) -> Self {
Self {
prefix: prefix.into(),
router,
}
}
}
impl ControllerTrait for PrefixController {
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>> {
router.nest(&self.prefix, self.router)
}
}
pub enum Controller {
Prefix(PrefixController),
}
impl Controller {
pub fn from_prefix(
prefix: impl Into<Cow<'static, str>>,
router: Router<Arc<AppContext>>,
) -> Self {
Self::Prefix(PrefixController::new(prefix, router))
}
}
impl ControllerTrait for Controller {
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>> {
match self {
Self::Prefix(p) => p.apply_to(router),
}
}
}

View File

@@ -0,0 +1,36 @@
import { type Fetcher, createGraphiQLFetcher } from '@graphiql/toolkit';
import { createFileRoute } from '@tanstack/react-router';
import GraphiQL from 'graphiql';
import { useMemo } from 'react';
import { beforeLoadGuard } from '../../../auth/guard';
import 'graphiql/graphiql.css';
import { firstValueFrom } from 'rxjs';
import { useAuth } from '../../../auth/hooks';
export const Route = createFileRoute('/graphql/')({
component: RouteComponent,
beforeLoad: beforeLoadGuard,
});
function RouteComponent() {
const { oidcSecurityService } = useAuth();
const fetcher = useMemo(
(): Fetcher => async (props) => {
const accessToken = oidcSecurityService
? await firstValueFrom(oidcSecurityService.getAccessToken())
: undefined;
return createGraphiQLFetcher({
url: '/api/graphql',
headers: accessToken
? {
Authorization: `Bearer ${accessToken}`,
}
: undefined,
})(props);
},
[oidcSecurityService]
);
return <GraphiQL fetcher={fetcher} className="h-svh" />;
}

View File

@@ -0,0 +1,33 @@
use std::sync::Arc;
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::{Extension, Router, extract::State, middleware::from_fn_with_state, routing::post};
use super::core::Controller;
use crate::{
app::AppContext,
auth::{AuthUserInfo, header_www_authenticate_middleware},
errors::RResult,
};
pub const CONTROLLER_PREFIX: &str = "/api/graphql";
async fn graphql_handler(
State(ctx): State<Arc<AppContext>>,
Extension(auth_user_info): Extension<AuthUserInfo>,
req: GraphQLRequest,
) -> GraphQLResponse {
let graphql_service = &ctx.graphql;
let mut req = req.into_inner();
req = req.data(auth_user_info);
graphql_service.schema.execute(req).await.into()
}
pub async fn create(ctx: Arc<AppContext>) -> RResult<Controller> {
let router = Router::<Arc<AppContext>>::new()
.route("/", post(graphql_handler))
.layer(from_fn_with_state(ctx, header_www_authenticate_middleware));
Ok(Controller::from_prefix(CONTROLLER_PREFIX, router))
}

View File

@@ -0,0 +1,9 @@
import { createFileRoute } from '@tanstack/react-router'
export const Route = createFileRoute('/')({
component: RouteComponent,
})
function RouteComponent() {
return <div>Hello to playground!</div>
}

View File

@@ -0,0 +1,5 @@
pub mod core;
pub mod graphql;
pub mod oidc;
pub use core::{Controller, ControllerTrait, PrefixController};

View File

@@ -0,0 +1,32 @@
import { createFileRoute, redirect } from '@tanstack/react-router';
import { EventTypes } from 'oidc-client-rx';
import { useAuth } from '../../../auth/hooks';
export const Route = createFileRoute('/oidc/callback')({
component: RouteComponent,
beforeLoad: ({ context }) => {
if (!context.oidcSecurityService) {
throw redirect({
to: '/',
});
}
},
});
function RouteComponent() {
const auth = useAuth();
if (!auth.checkAuthResultEvent) {
return <div>Loading...</div>;
}
return (
<div>
OpenID Connect Auth Callback:{' '}
{auth.checkAuthResultEvent?.type ===
EventTypes.CheckingAuthFinishedWithError
? auth.checkAuthResultEvent.value
: 'success'}
</div>
);
}

View File

@@ -0,0 +1,79 @@
use std::sync::Arc;
use axum::{
Json, Router,
extract::{Query, State},
http::request::Parts,
routing::get,
};
use super::core::Controller;
use crate::{
app::AppContext,
auth::{
AuthError, AuthService, AuthServiceTrait,
oidc::{OidcAuthCallbackPayload, OidcAuthCallbackQuery, OidcAuthRequest},
},
errors::RResult,
extract::http::ForwardedRelatedInfo,
models::auth::AuthType,
};
pub const CONTROLLER_PREFIX: &str = "/api/oidc";
async fn oidc_callback(
State(ctx): State<Arc<AppContext>>,
Query(query): Query<OidcAuthCallbackQuery>,
) -> Result<Json<OidcAuthCallbackPayload>, AuthError> {
let auth_service = &ctx.auth;
if let AuthService::Oidc(oidc_auth_service) = auth_service {
let response = oidc_auth_service
.extract_authorization_request_callback(query)
.await?;
Ok(Json(response))
} else {
Err(AuthError::NotSupportAuthMethod {
supported: vec![auth_service.auth_type()],
current: AuthType::Oidc,
})
}
}
async fn oidc_auth(
State(ctx): State<Arc<AppContext>>,
parts: Parts,
) -> Result<Json<OidcAuthRequest>, AuthError> {
let auth_service = &ctx.auth;
if let AuthService::Oidc(oidc_auth_service) = auth_service {
let mut redirect_uri = ForwardedRelatedInfo::from_request_parts(&parts)
.resolved_origin()
.ok_or_else(|| AuthError::OidcRequestRedirectUriError(url::ParseError::EmptyHost))?;
redirect_uri.set_path(&format!("{CONTROLLER_PREFIX}/callback"));
let auth_request = oidc_auth_service
.build_authorization_request(redirect_uri.as_str())
.await?;
{
oidc_auth_service
.store_authorization_request(auth_request.clone())
.await?;
}
Ok(Json(auth_request))
} else {
Err(AuthError::NotSupportAuthMethod {
supported: vec![auth_service.auth_type()],
current: AuthType::Oidc,
})
}
}
pub async fn create(_context: Arc<AppContext>) -> RResult<Controller> {
let router = Router::<Arc<AppContext>>::new()
.route("/auth", get(oidc_auth))
.route("/callback", get(oidc_callback));
Ok(Controller::from_prefix(CONTROLLER_PREFIX, router))
}

View File

@@ -0,0 +1,58 @@
//! Catch Panic Middleware for Axum
//!
//! This middleware catches panics that occur during request handling in the
//! application. When a panic occurs, it logs the error and returns an
//! internal server error response. This middleware helps ensure that the
//! application can gracefully handle unexpected errors without crashing the
//! server.
use std::sync::Arc;
use axum::{Router, response::IntoResponse};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use tower_http::catch_panic::CatchPanicLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CatchPanic {
#[serde(default)]
pub enable: bool,
}
/// Handler function for the [`CatchPanicLayer`] middleware.
///
/// This function processes panics by extracting error messages, logging them,
/// and returning an internal server error response.
#[allow(clippy::needless_pass_by_value)]
fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
let err = err.downcast_ref::<String>().map_or_else(
|| err.downcast_ref::<&str>().map_or("no error details", |s| s),
|s| s.as_str(),
);
tracing::error!(err.msg = err, "server_panic");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
impl MiddlewareLayer for CatchPanic {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"catch_panic"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the Catch Panic middleware layer to the Axum router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(CatchPanicLayer::custom(handle_panic)))
}
}

View File

@@ -0,0 +1,41 @@
//! Compression Middleware for Axum
//!
//! This middleware applies compression to HTTP responses to reduce the size of
//! the data being transmitted. This can improve performance by decreasing load
//! times and reducing bandwidth usage. The middleware configuration allows for
//! enabling or disabling compression based on the application settings.
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use tower_http::compression::CompressionLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Compression {
#[serde(default)]
pub enable: bool,
}
impl MiddlewareLayer for Compression {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"compression"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the Compression middleware layer to the Axum router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(CompressionLayer::new()))
}
}

View File

@@ -0,0 +1,163 @@
//! Configurable and Flexible CORS Middleware
//!
//! This middleware enables Cross-Origin Resource Sharing (CORS) by allowing
//! configurable origins, methods, and headers in HTTP requests. It can be
//! tailored to fit various application requirements, supporting permissive CORS
//! or specific rules as defined in the middleware configuration.
use std::{sync::Arc, time::Duration};
use axum::Router;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower_http::cors::{self, Any};
use crate::{app::AppContext, web::middleware::MiddlewareLayer, errors::RResult};
/// CORS middleware configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Cors {
#[serde(default)]
pub enable: bool,
/// Allow origins
#[serde(default = "default_allow_origins")]
pub allow_origins: Vec<String>,
/// Allow headers
#[serde(default = "default_allow_headers")]
pub allow_headers: Vec<String>,
/// Allow methods
#[serde(default = "default_allow_methods")]
pub allow_methods: Vec<String>,
/// Allow credentials
#[serde(default)]
pub allow_credentials: bool,
/// Max age
pub max_age: Option<u64>,
// Vary headers
#[serde(default = "default_vary_headers")]
pub vary: Vec<String>,
}
fn default_allow_origins() -> Vec<String> {
vec!["*".to_string()]
}
fn default_allow_headers() -> Vec<String> {
vec!["*".to_string()]
}
fn default_allow_methods() -> Vec<String> {
vec!["*".to_string()]
}
fn default_vary_headers() -> Vec<String> {
vec![
"origin".to_string(),
"access-control-request-method".to_string(),
"access-control-request-headers".to_string(),
]
}
impl Default for Cors {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
impl Cors {
/// Creates cors layer
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// - If any of the provided origins in `allow_origins` cannot be parsed as
/// a valid URI, the function will return a parsing error.
/// - If any of the provided headers in `allow_headers` cannot be parsed as
/// valid HTTP headers, the function will return a parsing error.
/// - If any of the provided methods in `allow_methods` cannot be parsed as
/// valid HTTP methods, the function will return a parsing error.
///
/// In all of these cases, the error returned will be the result of the
/// `parse` method of the corresponding type.
pub fn cors(&self) -> RResult<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::new();
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Acces
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
if self.allow_origins == default_allow_origins() {
cors = cors.allow_origin(Any);
} else {
let mut list = vec![];
for origin in &self.allow_origins {
list.push(origin.parse()?);
}
if !list.is_empty() {
cors = cors.allow_origin(list);
}
}
if self.allow_headers == default_allow_headers() {
cors = cors.allow_headers(Any);
} else {
let mut list = vec![];
for header in &self.allow_headers {
list.push(header.parse()?);
}
if !list.is_empty() {
cors = cors.allow_headers(list);
}
}
if self.allow_methods == default_allow_methods() {
cors = cors.allow_methods(Any);
} else {
let mut list = vec![];
for method in &self.allow_methods {
list.push(method.parse()?);
}
if !list.is_empty() {
cors = cors.allow_methods(list);
}
}
let mut list = vec![];
for v in &self.vary {
list.push(v.parse()?);
}
if !list.is_empty() {
cors = cors.vary(list);
}
if let Some(max_age) = self.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}
cors = cors.allow_credentials(self.allow_credentials);
Ok(cors)
}
}
impl MiddlewareLayer for Cors {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"cors"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the CORS middleware layer to the Axum router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(self.cors()?))
}
}

View File

@@ -0,0 +1,111 @@
//! `ETag` Middleware for Caching Requests
//!
//! This middleware implements the [ETag](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag)
//! HTTP header for caching responses in Axum. `ETags` are used to validate
//! cache entries by comparing a client's stored `ETag` with the one generated
//! by the server. If the `ETags` match, a `304 Not Modified` response is sent,
//! avoiding the need to resend the full content.
use std::{
sync::Arc,
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
extract::Request,
http::{
StatusCode,
header::{ETAG, IF_NONE_MATCH},
},
response::Response,
};
use futures_util::future::BoxFuture;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Etag {
#[serde(default)]
pub enable: bool,
}
impl MiddlewareLayer for Etag {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"etag"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the `ETag` middleware to the application router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(EtagLayer))
}
}
/// [`EtagLayer`] struct for adding `ETag` functionality as a Tower service
/// layer.
#[derive(Default, Clone)]
struct EtagLayer;
impl<S> Layer<S> for EtagLayer {
type Service = EtagMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
EtagMiddleware { inner }
}
}
#[derive(Clone)]
struct EtagMiddleware<S> {
inner: S,
}
impl<S> Service<Request<Body>> for EtagMiddleware<S>
where
S: Service<Request, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
// `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request) -> Self::Future {
let ifnm = request.headers().get(IF_NONE_MATCH).cloned();
let future = self.inner.call(request);
let res_fut = async move {
let response = future.await?;
let etag_from_response = response.headers().get(ETAG).cloned();
if let Some(etag_in_request) = ifnm {
if let Some(etag_from_response) = etag_from_response {
if etag_in_request == etag_from_response {
return Ok(Response::builder()
.status(StatusCode::NOT_MODIFIED)
.body(Body::empty())
.unwrap());
}
}
}
Ok(response)
};
Box::pin(res_fut)
}
}

View File

@@ -0,0 +1,71 @@
//! Detect a content type and format and responds accordingly
use axum::{
extract::FromRequestParts,
http::{
header::{ACCEPT, CONTENT_TYPE, HeaderMap},
request::Parts,
},
};
use serde::{Deserialize, Serialize};
use crate::errors::RError as Error;
#[derive(Debug, Deserialize, Serialize)]
pub struct Format(pub RespondTo);
#[derive(Debug, Deserialize, Serialize)]
pub enum RespondTo {
None,
Html,
Json,
Xml,
Other(String),
}
fn detect_format(content_type: &str) -> RespondTo {
if content_type.starts_with("application/json") {
RespondTo::Json
} else if content_type.starts_with("text/html") {
RespondTo::Html
} else if content_type.starts_with("text/xml")
|| content_type.starts_with("application/xml")
|| content_type.starts_with("application/xhtml")
{
RespondTo::Xml
} else {
RespondTo::Other(content_type.to_string())
}
}
pub fn get_respond_to(headers: &HeaderMap) -> RespondTo {
#[allow(clippy::option_if_let_else)]
if let Some(content_type) = headers.get(CONTENT_TYPE).and_then(|h| h.to_str().ok()) {
detect_format(content_type)
} else if let Some(content_type) = headers.get(ACCEPT).and_then(|h| h.to_str().ok()) {
detect_format(content_type)
} else {
RespondTo::None
}
}
impl<S> FromRequestParts<S> for Format
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
Ok(Self(get_respond_to(&parts.headers)))
}
}
impl<S> FromRequestParts<S> for RespondTo
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
Ok(get_respond_to(&parts.headers))
}
}

View File

@@ -0,0 +1,102 @@
//! Logger Middleware
//!
//! This middleware provides logging functionality for HTTP requests. It uses
//! `TraceLayer` to log detailed information about each request, such as the
//! HTTP method, URI, version, user agent, and an associated request ID.
//! Additionally, it integrates the application's runtime environment
//! into the log context, allowing environment-specific logging (e.g.,
//! "development", "production").
use std::sync::Arc;
use axum::{Router, http};
use serde::{Deserialize, Serialize};
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
use crate::{
app::{AppContext, Environment},
errors::RResult,
web::middleware::{MiddlewareLayer, request_id::LocoRequestId},
};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
#[serde(default)]
pub enable: bool,
}
/// [`Middleware`] struct responsible for logging HTTP requests.
#[derive(Serialize, Debug)]
pub struct Middleware {
config: Config,
environment: Environment,
}
/// Creates a new instance of [`Middleware`] by cloning the [`Config`]
/// configuration.
#[must_use]
pub fn new(config: &Config, context: Arc<AppContext>) -> Middleware {
Middleware {
config: config.clone(),
environment: context.environment.clone(),
}
}
impl MiddlewareLayer for Middleware {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"logger"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.config.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the logger middleware to the application router by adding layers
/// for:
///
/// - `TraceLayer`: Logs detailed information about each HTTP request.
/// - `AddExtensionLayer`: Adds the current environment to the request
/// extensions, making it accessible to the `TraceLayer` for logging.
///
/// The `TraceLayer` is customized with `make_span_with` to extract
/// request-specific details like method, URI, version, user agent, and
/// request ID, then create a tracing span for the request.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app
.layer(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
let ext = request.extensions();
let request_id = ext
.get::<LocoRequestId>()
.map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string());
let user_agent = request
.headers()
.get(axum::http::header::USER_AGENT)
.map_or("", |h| h.to_str().unwrap_or(""));
let env: String = request
.extensions()
.get::<Environment>()
.map(|e| e.full_name().to_string())
.unwrap_or_default();
tracing::error_span!(
"http-request",
"http.method" = tracing::field::display(request.method()),
"http.uri" = tracing::field::display(request.uri()),
"http.version" = tracing::field::debug(request.version()),
"http.user_agent" = tracing::field::display(user_agent),
"environment" = tracing::field::display(env),
request_id = tracing::field::display(request_id),
)
}),
)
.layer(AddExtensionLayer::new(self.environment.clone())))
}
}

View File

@@ -0,0 +1,165 @@
pub mod catch_panic;
pub mod compression;
pub mod cors;
pub mod etag;
pub mod format;
pub mod logger;
pub mod remote_ip;
pub mod request_id;
pub mod secure_headers;
pub mod timeout;
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use crate::{app::AppContext, errors::RResult};
/// Trait representing the behavior of middleware components in the application.
/// When implementing a new middleware, make sure to go over this checklist:
/// * The name of the middleware should be an ID that is similar to the field
/// name in configuration (look at how `serde` calls it)
/// * Default value implementation should be paired with `serde` default
/// handlers and default serialization implementation. Which means deriving
/// `Default` will _not_ work. You can use `serde_json` and serialize a new
/// config from an empty value, which will cause `serde` default value
/// handlers to kick in.
/// * If you need completely blank values for configuration (for example for
/// testing), implement an `::empty() -> Self` call ad-hoc.
pub trait MiddlewareLayer {
/// Returns the name of the middleware.
/// This should match the name of the property in the containing
/// `middleware` section in configuration (as named by `serde`)
fn name(&self) -> &'static str;
/// Returns whether the middleware is enabled or not.
/// If the middleware is switchable, take this value from a configuration
/// value
fn is_enabled(&self) -> bool {
true
}
/// Returns middleware config.
///
/// # Errors
/// when could not convert middleware to [`serde_json::Value`]
fn config(&self) -> serde_json::Result<serde_json::Value>;
/// Applies the middleware to the given Axum router and returns the modified
/// router.
///
/// # Errors
///
/// If there is an issue when adding the middleware to the router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>>;
}
#[allow(clippy::unnecessary_lazy_evaluations)]
#[must_use]
pub fn default_middleware_stack(ctx: Arc<AppContext>) -> Vec<Box<dyn MiddlewareLayer>> {
// Shortened reference to middlewares
let middlewares = &ctx.config.server.middlewares;
vec![
// CORS middleware with a default if none
Box::new(middlewares.cors.clone().unwrap_or_else(|| cors::Cors {
enable: false,
..Default::default()
})),
// Catch Panic middleware with a default if none
Box::new(
middlewares
.catch_panic
.clone()
.unwrap_or_else(|| catch_panic::CatchPanic { enable: true }),
),
// Etag middleware with a default if none
Box::new(
middlewares
.etag
.clone()
.unwrap_or_else(|| etag::Etag { enable: true }),
),
// Remote IP middleware with a default if none
Box::new(
middlewares
.remote_ip
.clone()
.unwrap_or_else(|| remote_ip::RemoteIpMiddleware {
enable: false,
..Default::default()
}),
),
// Compression middleware with a default if none
Box::new(
middlewares
.compression
.clone()
.unwrap_or_else(|| compression::Compression { enable: false }),
),
// Timeout Request middleware with a default if none
Box::new(
middlewares
.timeout_request
.clone()
.unwrap_or_else(|| timeout::TimeOut {
enable: false,
..Default::default()
}),
),
// Secure Headers middleware with a default if none
Box::new(middlewares.secure_headers.clone().unwrap_or_else(|| {
secure_headers::SecureHeader {
enable: false,
..Default::default()
}
})),
// Logger middleware with default logger configuration
Box::new(logger::new(
&middlewares
.logger
.clone()
.unwrap_or_else(|| logger::Config { enable: true }),
ctx.clone(),
)),
// Request ID middleware with a default if none
Box::new(
middlewares
.request_id
.clone()
.unwrap_or_else(|| request_id::RequestId { enable: true }),
),
]
}
/// Server middleware configuration structure.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
pub struct MiddlewareConfig {
/// Compression for the response.
pub compression: Option<compression::Compression>,
/// Etag cache headers.
pub etag: Option<etag::Etag>,
/// Logger and augmenting trace id with request data
pub logger: Option<logger::Config>,
/// Catch any code panic and log the error.
pub catch_panic: Option<catch_panic::CatchPanic>,
/// Setting a global timeout for requests
pub timeout_request: Option<timeout::TimeOut>,
/// CORS configuration
pub cors: Option<cors::Cors>,
/// Sets a set of secure headers
pub secure_headers: Option<secure_headers::SecureHeader>,
/// Calculates a remote IP based on `X-Forwarded-For` when behind a proxy
pub remote_ip: Option<remote_ip::RemoteIpMiddleware>,
/// Request ID
pub request_id: Option<request_id::RequestId>,
}

View File

@@ -0,0 +1,306 @@
//! Remote IP Middleware for inferring the client's IP address based on the
//! `X-Forwarded-For` header.
//!
//! This middleware is useful when running behind proxies or load balancers that
//! add the `X-Forwarded-For` header, which includes the original client IP
//! address.
//!
//! The middleware provides a mechanism to configure trusted proxies and extract
//! the most likely client IP from the `X-Forwarded-For` header, skipping any
//! trusted proxy IPs.
use std::{
fmt,
iter::Iterator,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::{Arc, OnceLock},
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
extract::{ConnectInfo, FromRequestParts, Request},
http::{header::HeaderMap, request::Parts},
response::Response,
};
use futures_util::future::BoxFuture;
use ipnetwork::IpNetwork;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use tracing::error;
use crate::{
app::AppContext,
errors::{RError, RResult},
web::middleware::MiddlewareLayer,
};
static LOCAL_TRUSTED_PROXIES: OnceLock<Vec<IpNetwork>> = OnceLock::new();
fn get_local_trusted_proxies() -> &'static Vec<IpNetwork> {
LOCAL_TRUSTED_PROXIES.get_or_init(|| {
[
"127.0.0.0/8", // localhost IPv4 range, per RFC-3330
"::1", // localhost IPv6
"fc00::/7", // private IPv6 range fc00::/7
"10.0.0.0/8", // private IPv4 range 10.x.x.x
"172.16.0.0/12", // private IPv4 range 172.16.0.0 .. 172.31.255.255
"192.168.0.0/16",
]
.iter()
.map(|ip| IpNetwork::from_str(ip).unwrap())
.collect()
})
}
const X_FORWARDED_FOR: &str = "X-Forwarded-For";
///
/// Performs a remote ip "calculation", inferring the most likely
/// client IP from the `X-Forwarded-For` header that is used by
/// load balancers and proxies.
///
/// WARNING
/// =======
///
/// LIKE ANY SUCH REMOTE IP MIDDLEWARE, IN THE WRONG ARCHITECTURE IT CAN MAKE
/// YOU VULNERABLE TO IP SPOOFING.
///
/// This middleware assumes that there is at least one proxy sitting around and
/// setting headers with the client's remote IP address. Otherwise any client
/// can claim to have any IP address by setting the `X-Forwarded-For` header.
///
/// DO NOT USE THIS MIDDLEWARE IF YOU DONT KNOW THAT YOU NEED IT
///
/// -- But if you need it, it's crucial to use it (since it's the only way to
/// get the original client IP)
///
/// This middleware is mostly implemented after the Rails `remote_ip`
/// middleware, and looking at other production Rust services with Axum, taking
/// the best of both worlds to balance performance and pragmatism.
///
/// Similarities to the Rails `remote_ip` middleware:
///
/// * Uses `X-Forwarded-For`
/// * Uses the same built-in trusted proxies list
/// * You can provide a list of `trusted_proxies` which will **replace** the
/// built-in trusted proxies
///
/// Differences from the Rails `remote_ip` middleware:
///
/// * You get an indication if the remote IP is actually resolved or is the
/// socket IP (no `X-Forwarded-For` header or could not parse)
/// * We do not not use the `Client-IP` header, or try to detect "spoofing"
/// (spoofing while doing remote IP resolution is virtually non-detectable)
/// * Order of filtering IPs from `X-Forwarded-For` is done according to [the de
/// facto spec](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#selecting_an_ip_address)
/// "Trusted proxy list"
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct RemoteIpMiddleware {
#[serde(default)]
pub enable: bool,
/// A list of alternative proxy list IP ranges and/or network range (will
/// replace built-in proxy list)
pub trusted_proxies: Option<Vec<String>>,
}
impl MiddlewareLayer for RemoteIpMiddleware {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"remote_ip"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
&& (self.trusted_proxies.is_none()
|| self.trusted_proxies.as_ref().is_some_and(|t| !t.is_empty()))
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the Remote IP middleware to the given Axum router.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(RemoteIPLayer::new(self)?))
}
}
// implementation reference: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
fn maybe_get_forwarded(
headers: &HeaderMap,
trusted_proxies: Option<&Vec<IpNetwork>>,
) -> Option<IpAddr> {
/*
> There may be multiple X-Forwarded-For headers present in a request. The IP addresses in these headers must be treated as a single list,
> starting with the first IP address of the first header and continuing to the last IP address of the last header.
> There are two ways of making this single list:
> join the X-Forwarded-For full header values with commas and then split by comma into a list, or
> split each X-Forwarded-For header by comma into lists and then join the lists
*/
let xffs = headers
.get_all(X_FORWARDED_FOR)
.iter()
.map(|hdr| hdr.to_str())
.filter_map(Result::ok)
.collect::<Vec<_>>();
if xffs.is_empty() {
return None;
}
let forwarded = xffs.join(",");
forwarded
.split(',')
.map(str::trim)
.map(str::parse)
.filter_map(Result::ok)
/*
> Trusted proxy list: The IPs or IP ranges of the trusted reverse proxies are configured.
> The X-Forwarded-For IP list is searched from the rightmost, skipping all addresses that
> are on the trusted proxy list. The first non-matching address is the target address.
*/
.filter(|ip| {
// trusted proxies provided REPLACES our default local proxies
let proxies = trusted_proxies.unwrap_or_else(|| get_local_trusted_proxies());
!proxies
.iter()
.any(|trusted_proxy| trusted_proxy.contains(*ip))
})
/*
> When choosing the X-Forwarded-For client IP address closest to the client (untrustworthy
> and not for security-related purposes), the first IP from the leftmost that is a valid
> address and not private/internal should be selected.
>
NOTE:
> The first trustworthy X-Forwarded-For IP address may belong to an untrusted intermediate
> proxy rather than the actual client computer, but it is the only IP suitable for security uses.
*/
.next_back()
}
#[derive(Copy, Clone, Debug)]
pub enum RemoteIP {
Forwarded(IpAddr),
Socket(IpAddr),
None,
}
impl<S> FromRequestParts<S> for RemoteIP
where
S: Send + Sync,
{
type Rejection = ();
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let ip = parts.extensions.get::<Self>();
Ok(*ip.unwrap_or(&Self::None))
}
}
impl fmt::Display for RemoteIP {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Forwarded(ip) => write!(f, "remote: {ip}"),
Self::Socket(ip) => write!(f, "socket: {ip}"),
Self::None => write!(f, "--"),
}
}
}
#[derive(Clone, Debug)]
struct RemoteIPLayer {
trusted_proxies: Option<Vec<IpNetwork>>,
}
impl RemoteIPLayer {
/// Returns new secure headers middleware
///
/// # Errors
/// Fails if invalid header values found
pub fn new(config: &RemoteIpMiddleware) -> RResult<Self> {
Ok(Self {
trusted_proxies: config
.trusted_proxies
.as_ref()
.map(|proxies| {
proxies
.iter()
.map(|proxy| {
IpNetwork::from_str(proxy).map_err(|err| {
RError::CustomMessageString(format!(
"remote ip middleare cannot parse trusted proxy \
configuration: `{proxy}`, reason: `{err}`",
))
})
})
.collect::<RResult<Vec<_>>>()
})
.transpose()?,
})
}
}
impl<S> Layer<S> for RemoteIPLayer {
type Service = RemoteIPMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
RemoteIPMiddleware {
inner,
layer: self.clone(),
}
}
}
/// Remote IP Detection Middleware
#[derive(Clone, Debug)]
#[must_use]
pub struct RemoteIPMiddleware<S> {
inner: S,
layer: RemoteIPLayer,
}
impl<S> Service<Request<Body>> for RemoteIPMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let layer = self.layer.clone();
let xff_ip = maybe_get_forwarded(req.headers(), layer.trusted_proxies.as_ref());
let remote_ip = xff_ip.map_or_else(
|| {
let ip = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map_or_else(
|| {
error!(
"remote ip middleware cannot get socket IP (not set in axum \
extensions): setting IP to `127.0.0.1`"
);
RemoteIP::None
},
|info| RemoteIP::Socket(info.ip()),
);
ip
},
RemoteIP::Forwarded,
);
req.extensions_mut().insert(remote_ip);
Box::pin(self.inner.call(req))
}
}

View File

@@ -0,0 +1,132 @@
//! Middleware to generate or ensure a unique request ID for every request.
//!
//! The request ID is stored in the `x-request-id` header, and it is either
//! generated or sanitized if already present in the request.
//!
//! This can be useful for tracking requests across services, logging, and
//! debugging.
use axum::{Router, extract::Request, http::HeaderValue, middleware::Next, response::Response};
use regex::Regex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{web::middleware::MiddlewareLayer, app::AppContext, errors::RResult};
const X_REQUEST_ID: &str = "x-request-id";
const MAX_LEN: usize = 255;
use std::sync::{Arc, OnceLock};
static ID_CLEANUP: OnceLock<Regex> = OnceLock::new();
fn get_id_cleanup() -> &'static Regex {
ID_CLEANUP.get_or_init(|| Regex::new(r"[^\w\-@]").unwrap())
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RequestId {
#[serde(default)]
pub enable: bool,
}
impl MiddlewareLayer for RequestId {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"request_id"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the request ID middleware to the Axum router.
///
/// This function sets up the middleware in the router and ensures that
/// every request passing through it will have a unique or sanitized
/// request ID.
///
/// # Errors
/// This function returns an error if the middleware cannot be applied.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(axum::middleware::from_fn(request_id_middleware)))
}
}
/// Wrapper struct for storing the request ID in the request's extensions.
#[derive(Debug, Clone)]
pub struct LocoRequestId(String);
impl LocoRequestId {
/// Retrieves the request ID as a string slice.
#[must_use]
pub fn get(&self) -> &str {
self.0.as_str()
}
}
/// Middleware function to ensure or generate a unique request ID.
///
/// This function intercepts requests, checks for the presence of the
/// `x-request-id` header, and either sanitizes its value or generates a new
/// UUID if absent. The resulting request ID is added to both the request
/// extensions and the response headers.
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
let header_request_id = request.headers().get(X_REQUEST_ID).cloned();
let request_id = make_request_id(header_request_id);
request
.extensions_mut()
.insert(LocoRequestId(request_id.clone()));
let mut res = next.run(request).await;
if let Ok(v) = HeaderValue::from_str(request_id.as_str()) {
res.headers_mut().insert(X_REQUEST_ID, v);
} else {
tracing::warn!("could not set request ID into response headers: `{request_id}`",);
}
res
}
/// Generates or sanitizes a request ID.
fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
maybe_request_id
.and_then(|hdr| {
// see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39
let id: Option<String> = hdr.to_str().ok().map(|s| {
get_id_cleanup()
.replace_all(s, "")
.chars()
.take(MAX_LEN)
.collect()
});
id.filter(|s| !s.is_empty())
})
.unwrap_or_else(|| Uuid::new_v4().to_string())
}
#[cfg(test)]
mod tests {
use axum::http::HeaderValue;
use insta::assert_debug_snapshot;
use super::make_request_id;
#[test]
fn create_or_fetch_request_id() {
let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
assert_debug_snapshot!(id);
let id = make_request_id(Some(HeaderValue::from_static("")));
assert_debug_snapshot!(id.len());
let id = make_request_id(Some(HeaderValue::from_static("==========")));
assert_debug_snapshot!(id.len());
let long_id = "x".repeat(1000);
let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
assert_debug_snapshot!(id.len());
let id = make_request_id(None);
assert_debug_snapshot!(id.len());
}
}

View File

@@ -0,0 +1,26 @@
{
"empty":{},
"github":{
"Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src https:; style-src 'self' https: 'unsafe-inline'",
"Strict-Transport-Security": "max-age=631138519",
"X-Content-Type-Options": "nosniff",
"X-Download-Options": "noopen",
"X-Frame-Options": "sameorigin",
"X-Permitted-Cross-Domain-Policies": "none",
"X-Xss-Protection": "0"
},
"owasp":{
"Cache-Control": "no-store, max-age=0",
"Clear-Site-Data": "\"cache\",\"cookies\",\"storage\"",
"Content-Security-Policy": "default-src 'self'; form-action 'self'; object-src 'none'; frame-ancestors 'none'; upgrade-insecure-requests; block-all-mixed-content",
"Cross-Origin-Embedder-Policy": "require-corp",
"Cross-Origin-Opener-Policy": "same-origin",
"Cross-Origin-Resource-Policy": "same-origin",
"Permissions-Policy": "accelerometer=(), autoplay=(), camera=(), cross-origin-isolated=(), display-capture=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(self), usb=(), web-share=(), xr-spatial-tracking=(), clipboard-read=(), clipboard-write=(), gamepad=(), hid=(), idle-detection=(), interest-cohort=(), serial=(), unload=()",
"Referrer-Policy": "no-referrer",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "deny",
"X-Permitted-Cross-Domain-Policies": "none"
}
}

View File

@@ -0,0 +1,311 @@
//! Sets secure headers for your backend to promote security-by-default.
//!
//! This middleware applies secure HTTP headers, providing pre-defined presets
//! (e.g., "github") and the ability to override or define custom headers.
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, OnceLock},
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
http::{HeaderName, HeaderValue, Request},
response::Response,
};
use futures_util::future::BoxFuture;
use serde::{Deserialize, Serialize};
use serde_json::{self, json};
use tower::{Layer, Service};
use crate::{
app::AppContext,
web::middleware::MiddlewareLayer,
errors::{RError, RResult},
};
static PRESETS: OnceLock<HashMap<String, BTreeMap<String, String>>> = OnceLock::new();
fn get_presets() -> &'static HashMap<String, BTreeMap<String, String>> {
PRESETS.get_or_init(|| {
let json_data = include_str!("secure_headers.json");
serde_json::from_str(json_data).unwrap()
})
}
/// Sets a predefined or custom set of secure headers.
///
/// We recommend our `github` preset. Presets values are derived
/// from the [secure_headers](https://github.com/github/secure_headers) Ruby
/// library which Github (and originally Twitter) use.
///
/// To use a preset, in your `config/development.yaml`:
///
/// ```yaml
/// middlewares:
/// secure_headers:
/// preset: github
/// ```
///
/// You can also override individual headers on a given preset:
///
/// ```yaml
/// middlewares:
/// secure_headers:
/// preset: github
/// overrides:
/// foo: bar
/// ```
///
/// Or start from scratch:
///
///```yaml
/// middlewares:
/// secure_headers:
/// preset: empty
/// overrides:
/// one: two
/// ```
///
/// To support `htmx`, You can add the following override, to allow some inline
/// running of scripts:
///
/// ```yaml
/// secure_headers:
/// preset: github
/// overrides:
/// # this allows you to use HTMX, and has unsafe-inline. Remove or consider in production
/// "Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src 'unsafe-inline' 'self' https:; style-src 'self' https: 'unsafe-inline'"
/// ```
///
/// For the list of presets and their content look at [secure_headers.json](https://github.com/loco-rs/loco/blob/master/src/controller/middleware/secure_headers.rs)
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SecureHeader {
#[serde(default)]
pub enable: bool,
#[serde(default = "default_preset")]
pub preset: String,
#[serde(default)]
pub overrides: Option<BTreeMap<String, String>>,
}
impl Default for SecureHeader {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
fn default_preset() -> String {
"github".to_string()
}
impl MiddlewareLayer for SecureHeader {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"secure_headers"
}
/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the secure headers layer to the application router
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(SecureHeaders::new(self)?))
}
}
impl SecureHeader {
/// Converts the configuration into a list of headers.
///
/// Applies the preset headers and any custom overrides.
fn as_headers(&self) -> RResult<Vec<(HeaderName, HeaderValue)>> {
let mut headers = vec![];
let preset = &self.preset;
let p = get_presets().get(preset).ok_or_else(|| {
RError::CustomMessageString(format!(
"secure_headers: a preset named `{preset}` does not exist"
))
})?;
Self::push_headers(&mut headers, p)?;
if let Some(overrides) = &self.overrides {
Self::push_headers(&mut headers, overrides)?;
}
Ok(headers)
}
/// Helper function to push headers into a mutable vector.
///
/// This function takes a map of header names and values, converting them
/// into valid HTTP headers and adding them to the provided `headers`
/// vector.
fn push_headers(
headers: &mut Vec<(HeaderName, HeaderValue)>,
hm: &BTreeMap<String, String>,
) -> RResult<()> {
for (k, v) in hm {
headers.push((
HeaderName::from_bytes(k.clone().as_bytes())?,
HeaderValue::from_str(v.clone().as_str())?,
));
}
Ok(())
}
}
/// The [`SecureHeaders`] layer which wraps around the service and injects
/// security headers
#[derive(Clone, Debug)]
pub struct SecureHeaders {
headers: Vec<(HeaderName, HeaderValue)>,
}
impl SecureHeaders {
/// Creates a new [`SecureHeaders`] instance with the provided
/// configuration.
///
/// # Errors
/// Returns an error if any header values are invalid.
pub fn new(config: &SecureHeader) -> RResult<Self> {
Ok(Self {
headers: config.as_headers()?,
})
}
}
impl<S> Layer<S> for SecureHeaders {
type Service = SecureHeadersMiddleware<S>;
/// Wraps the provided service with the secure headers middleware.
fn layer(&self, inner: S) -> Self::Service {
SecureHeadersMiddleware {
inner,
layer: self.clone(),
}
}
}
/// The secure headers middleware
#[derive(Clone, Debug)]
#[must_use]
pub struct SecureHeadersMiddleware<S> {
inner: S,
layer: SecureHeaders,
}
impl<S> Service<Request<Body>> for SecureHeadersMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let layer = self.layer.clone();
let future = self.inner.call(request);
Box::pin(async move {
let mut response: Response = future.await?;
let headers = response.headers_mut();
for (k, v) in &layer.headers {
headers.insert(k, v.clone());
}
Ok(response)
})
}
}
#[cfg(test)]
mod tests {
use axum::{
Router,
http::{HeaderMap, Method},
routing::get,
};
use insta::assert_debug_snapshot;
use tower::ServiceExt;
use super::*;
fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
headers
.iter()
.map(|(k, v)| {
let key = k.to_string();
let value = v.to_str().unwrap_or("").to_string();
(key, value)
})
.collect()
}
#[tokio::test]
async fn can_set_headers() {
let config = SecureHeader {
enable: true,
preset: "github".to_string(),
overrides: None,
};
let app = Router::new()
.route("/", get(|| async {}))
.layer(SecureHeaders::new(&config).unwrap());
let req = Request::builder()
.uri("/")
.method(Method::GET)
.body(Body::empty())
.unwrap();
let response = app.oneshot(req).await.unwrap();
assert_debug_snapshot!(normalize_headers(response.headers()));
}
#[tokio::test]
async fn can_override_headers() {
let mut overrides = BTreeMap::new();
overrides.insert("X-Download-Options".to_string(), "foobar".to_string());
overrides.insert("New-Header".to_string(), "baz".to_string());
let config = SecureHeader {
enable: true,
preset: "github".to_string(),
overrides: Some(overrides),
};
let app = Router::new()
.route("/", get(|| async {}))
.layer(SecureHeaders::new(&config).unwrap());
let req = Request::builder()
.uri("/")
.method(Method::GET)
.body(Body::empty())
.unwrap();
let response = app.oneshot(req).await.unwrap();
assert_debug_snapshot!(normalize_headers(response.headers()));
}
#[tokio::test]
async fn default_is_github_preset() {
let config = SecureHeader::default();
let app = Router::new()
.route("/", get(|| async {}))
.layer(SecureHeaders::new(&config).unwrap());
let req = Request::builder()
.uri("/")
.method(Method::GET)
.body(Body::empty())
.unwrap();
let response = app.oneshot(req).await.unwrap();
assert_debug_snapshot!(normalize_headers(response.headers()));
}
}

View File

@@ -0,0 +1,64 @@
//! Timeout Request Middleware.
//!
//! This middleware applies a timeout to requests processed by the application.
//! The timeout duration is configurable and defined via the
//! [`TimeoutRequestMiddleware`] configuration. The middleware ensures that
//! requests do not run beyond the specified timeout period, improving the
//! overall performance and responsiveness of the application.
//!
//! If a request exceeds the specified timeout duration, the middleware will
//! return a `408 Request Timeout` status code to the client, indicating that
//! the request took too long to process.
use std::{sync::Arc, time::Duration};
use axum::Router;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower_http::timeout::TimeoutLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
/// Timeout middleware configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TimeOut {
#[serde(default)]
pub enable: bool,
// Timeout request in milliseconds
#[serde(default = "default_timeout")]
pub timeout: u64,
}
impl Default for TimeOut {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
fn default_timeout() -> u64 {
5_000
}
impl MiddlewareLayer for TimeOut {
/// Returns the name of the middleware.
fn name(&self) -> &'static str {
"timeout_request"
}
/// Checks if the timeout middleware is enabled.
fn is_enabled(&self) -> bool {
self.enable
}
fn config(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self)
}
/// Applies the timeout middleware to the application router.
///
/// This method wraps the provided [`AXRouter`] in a [`TimeoutLayer`],
/// ensuring that requests exceeding the specified timeout duration will
/// be interrupted.
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
Ok(app.layer(TimeoutLayer::new(Duration::from_millis(self.timeout))))
}
}

View File

@@ -0,0 +1,5 @@
pub mod config;
pub mod controller;
pub mod middleware;
pub use config::WebServerConfig;