refactor: remove loco-rs deps
This commit is contained in:
30
apps/recorder/src/web/config.rs
Normal file
30
apps/recorder/src/web/config.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::middleware::MiddlewareConfig;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WebServerConfig {
|
||||
/// The address on which the server should listen on for incoming
|
||||
/// connections.
|
||||
#[serde(default = "default_binding")]
|
||||
pub binding: String,
|
||||
/// The port on which the server should listen for incoming connections.
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
/// The webserver host
|
||||
pub host: String,
|
||||
/// Identify via the `Server` header
|
||||
pub ident: Option<String>,
|
||||
/// Middleware configurations for the server, including payload limits,
|
||||
/// logging, and error handling.
|
||||
#[serde(default)]
|
||||
pub middlewares: MiddlewareConfig,
|
||||
}
|
||||
|
||||
pub fn default_binding() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
pub fn default_port() -> u16 {
|
||||
5_001
|
||||
}
|
||||
52
apps/recorder/src/web/controller/__root.tsx
Normal file
52
apps/recorder/src/web/controller/__root.tsx
Normal file
@@ -0,0 +1,52 @@
|
||||
import type { Injector } from '@outposts/injection-js';
|
||||
import {
|
||||
// Link,
|
||||
Outlet,
|
||||
createRootRouteWithContext,
|
||||
} from '@tanstack/react-router';
|
||||
import { TanStackRouterDevtools } from '@tanstack/router-devtools';
|
||||
import type { OidcSecurityService } from 'oidc-client-rx';
|
||||
|
||||
export type RouterContext =
|
||||
| {
|
||||
isAuthenticated: false;
|
||||
injector: Injector;
|
||||
oidcSecurityService: OidcSecurityService;
|
||||
}
|
||||
| {
|
||||
isAuthenticated: true;
|
||||
injector?: Injector;
|
||||
oidcSecurityService?: OidcSecurityService;
|
||||
};
|
||||
|
||||
export const Route = createRootRouteWithContext<RouterContext>()({
|
||||
component: RootComponent,
|
||||
});
|
||||
|
||||
function RootComponent() {
|
||||
return (
|
||||
<>
|
||||
{/* <div className="flex gap-2 p-2 text-lg ">
|
||||
<Link
|
||||
to="/"
|
||||
activeProps={{
|
||||
className: 'font-bold',
|
||||
}}
|
||||
>
|
||||
Home
|
||||
</Link>{' '}
|
||||
<Link
|
||||
to="/graphql"
|
||||
activeProps={{
|
||||
className: 'font-bold',
|
||||
}}
|
||||
>
|
||||
GraphQL
|
||||
</Link>
|
||||
</div> */}
|
||||
{/* <hr /> */}
|
||||
<Outlet />
|
||||
<TanStackRouterDevtools position="bottom-right" />
|
||||
</>
|
||||
);
|
||||
}
|
||||
50
apps/recorder/src/web/controller/core.rs
Normal file
50
apps/recorder/src/web/controller/core.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
|
||||
use axum::Router;
|
||||
|
||||
use crate::app::AppContext;
|
||||
|
||||
pub trait ControllerTrait: Sized {
|
||||
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>>;
|
||||
}
|
||||
|
||||
pub struct PrefixController {
|
||||
prefix: Cow<'static, str>,
|
||||
router: Router<Arc<AppContext>>,
|
||||
}
|
||||
|
||||
impl PrefixController {
|
||||
pub fn new(prefix: impl Into<Cow<'static, str>>, router: Router<Arc<AppContext>>) -> Self {
|
||||
Self {
|
||||
prefix: prefix.into(),
|
||||
router,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ControllerTrait for PrefixController {
|
||||
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>> {
|
||||
router.nest(&self.prefix, self.router)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Controller {
|
||||
Prefix(PrefixController),
|
||||
}
|
||||
|
||||
impl Controller {
|
||||
pub fn from_prefix(
|
||||
prefix: impl Into<Cow<'static, str>>,
|
||||
router: Router<Arc<AppContext>>,
|
||||
) -> Self {
|
||||
Self::Prefix(PrefixController::new(prefix, router))
|
||||
}
|
||||
}
|
||||
|
||||
impl ControllerTrait for Controller {
|
||||
fn apply_to(self, router: Router<Arc<AppContext>>) -> Router<Arc<AppContext>> {
|
||||
match self {
|
||||
Self::Prefix(p) => p.apply_to(router),
|
||||
}
|
||||
}
|
||||
}
|
||||
36
apps/recorder/src/web/controller/graphql/index.tsx
Normal file
36
apps/recorder/src/web/controller/graphql/index.tsx
Normal file
@@ -0,0 +1,36 @@
|
||||
import { type Fetcher, createGraphiQLFetcher } from '@graphiql/toolkit';
|
||||
import { createFileRoute } from '@tanstack/react-router';
|
||||
import GraphiQL from 'graphiql';
|
||||
import { useMemo } from 'react';
|
||||
import { beforeLoadGuard } from '../../../auth/guard';
|
||||
import 'graphiql/graphiql.css';
|
||||
import { firstValueFrom } from 'rxjs';
|
||||
import { useAuth } from '../../../auth/hooks';
|
||||
|
||||
export const Route = createFileRoute('/graphql/')({
|
||||
component: RouteComponent,
|
||||
beforeLoad: beforeLoadGuard,
|
||||
});
|
||||
|
||||
function RouteComponent() {
|
||||
const { oidcSecurityService } = useAuth();
|
||||
|
||||
const fetcher = useMemo(
|
||||
(): Fetcher => async (props) => {
|
||||
const accessToken = oidcSecurityService
|
||||
? await firstValueFrom(oidcSecurityService.getAccessToken())
|
||||
: undefined;
|
||||
return createGraphiQLFetcher({
|
||||
url: '/api/graphql',
|
||||
headers: accessToken
|
||||
? {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
}
|
||||
: undefined,
|
||||
})(props);
|
||||
},
|
||||
[oidcSecurityService]
|
||||
);
|
||||
|
||||
return <GraphiQL fetcher={fetcher} className="h-svh" />;
|
||||
}
|
||||
33
apps/recorder/src/web/controller/graphql/mod.rs
Normal file
33
apps/recorder/src/web/controller/graphql/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
|
||||
use axum::{Extension, Router, extract::State, middleware::from_fn_with_state, routing::post};
|
||||
|
||||
use super::core::Controller;
|
||||
use crate::{
|
||||
app::AppContext,
|
||||
auth::{AuthUserInfo, header_www_authenticate_middleware},
|
||||
errors::RResult,
|
||||
};
|
||||
|
||||
pub const CONTROLLER_PREFIX: &str = "/api/graphql";
|
||||
|
||||
async fn graphql_handler(
|
||||
State(ctx): State<Arc<AppContext>>,
|
||||
Extension(auth_user_info): Extension<AuthUserInfo>,
|
||||
req: GraphQLRequest,
|
||||
) -> GraphQLResponse {
|
||||
let graphql_service = &ctx.graphql;
|
||||
|
||||
let mut req = req.into_inner();
|
||||
req = req.data(auth_user_info);
|
||||
|
||||
graphql_service.schema.execute(req).await.into()
|
||||
}
|
||||
|
||||
pub async fn create(ctx: Arc<AppContext>) -> RResult<Controller> {
|
||||
let router = Router::<Arc<AppContext>>::new()
|
||||
.route("/", post(graphql_handler))
|
||||
.layer(from_fn_with_state(ctx, header_www_authenticate_middleware));
|
||||
Ok(Controller::from_prefix(CONTROLLER_PREFIX, router))
|
||||
}
|
||||
9
apps/recorder/src/web/controller/index.tsx
Normal file
9
apps/recorder/src/web/controller/index.tsx
Normal file
@@ -0,0 +1,9 @@
|
||||
import { createFileRoute } from '@tanstack/react-router'
|
||||
|
||||
export const Route = createFileRoute('/')({
|
||||
component: RouteComponent,
|
||||
})
|
||||
|
||||
function RouteComponent() {
|
||||
return <div>Hello to playground!</div>
|
||||
}
|
||||
5
apps/recorder/src/web/controller/mod.rs
Normal file
5
apps/recorder/src/web/controller/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod core;
|
||||
pub mod graphql;
|
||||
pub mod oidc;
|
||||
|
||||
pub use core::{Controller, ControllerTrait, PrefixController};
|
||||
32
apps/recorder/src/web/controller/oidc/callback.tsx
Normal file
32
apps/recorder/src/web/controller/oidc/callback.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { createFileRoute, redirect } from '@tanstack/react-router';
|
||||
import { EventTypes } from 'oidc-client-rx';
|
||||
import { useAuth } from '../../../auth/hooks';
|
||||
|
||||
export const Route = createFileRoute('/oidc/callback')({
|
||||
component: RouteComponent,
|
||||
beforeLoad: ({ context }) => {
|
||||
if (!context.oidcSecurityService) {
|
||||
throw redirect({
|
||||
to: '/',
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
function RouteComponent() {
|
||||
const auth = useAuth();
|
||||
|
||||
if (!auth.checkAuthResultEvent) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
OpenID Connect Auth Callback:{' '}
|
||||
{auth.checkAuthResultEvent?.type ===
|
||||
EventTypes.CheckingAuthFinishedWithError
|
||||
? auth.checkAuthResultEvent.value
|
||||
: 'success'}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
79
apps/recorder/src/web/controller/oidc/mod.rs
Normal file
79
apps/recorder/src/web/controller/oidc/mod.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::{Query, State},
|
||||
http::request::Parts,
|
||||
routing::get,
|
||||
};
|
||||
|
||||
use super::core::Controller;
|
||||
use crate::{
|
||||
app::AppContext,
|
||||
auth::{
|
||||
AuthError, AuthService, AuthServiceTrait,
|
||||
oidc::{OidcAuthCallbackPayload, OidcAuthCallbackQuery, OidcAuthRequest},
|
||||
},
|
||||
errors::RResult,
|
||||
extract::http::ForwardedRelatedInfo,
|
||||
models::auth::AuthType,
|
||||
};
|
||||
|
||||
pub const CONTROLLER_PREFIX: &str = "/api/oidc";
|
||||
|
||||
async fn oidc_callback(
|
||||
State(ctx): State<Arc<AppContext>>,
|
||||
Query(query): Query<OidcAuthCallbackQuery>,
|
||||
) -> Result<Json<OidcAuthCallbackPayload>, AuthError> {
|
||||
let auth_service = &ctx.auth;
|
||||
if let AuthService::Oidc(oidc_auth_service) = auth_service {
|
||||
let response = oidc_auth_service
|
||||
.extract_authorization_request_callback(query)
|
||||
.await?;
|
||||
Ok(Json(response))
|
||||
} else {
|
||||
Err(AuthError::NotSupportAuthMethod {
|
||||
supported: vec![auth_service.auth_type()],
|
||||
current: AuthType::Oidc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn oidc_auth(
|
||||
State(ctx): State<Arc<AppContext>>,
|
||||
parts: Parts,
|
||||
) -> Result<Json<OidcAuthRequest>, AuthError> {
|
||||
let auth_service = &ctx.auth;
|
||||
if let AuthService::Oidc(oidc_auth_service) = auth_service {
|
||||
let mut redirect_uri = ForwardedRelatedInfo::from_request_parts(&parts)
|
||||
.resolved_origin()
|
||||
.ok_or_else(|| AuthError::OidcRequestRedirectUriError(url::ParseError::EmptyHost))?;
|
||||
|
||||
redirect_uri.set_path(&format!("{CONTROLLER_PREFIX}/callback"));
|
||||
|
||||
let auth_request = oidc_auth_service
|
||||
.build_authorization_request(redirect_uri.as_str())
|
||||
.await?;
|
||||
|
||||
{
|
||||
oidc_auth_service
|
||||
.store_authorization_request(auth_request.clone())
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(Json(auth_request))
|
||||
} else {
|
||||
Err(AuthError::NotSupportAuthMethod {
|
||||
supported: vec![auth_service.auth_type()],
|
||||
current: AuthType::Oidc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create(_context: Arc<AppContext>) -> RResult<Controller> {
|
||||
let router = Router::<Arc<AppContext>>::new()
|
||||
.route("/auth", get(oidc_auth))
|
||||
.route("/callback", get(oidc_callback));
|
||||
|
||||
Ok(Controller::from_prefix(CONTROLLER_PREFIX, router))
|
||||
}
|
||||
58
apps/recorder/src/web/middleware/catch_panic.rs
Normal file
58
apps/recorder/src/web/middleware/catch_panic.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
//! Catch Panic Middleware for Axum
|
||||
//!
|
||||
//! This middleware catches panics that occur during request handling in the
|
||||
//! application. When a panic occurs, it logs the error and returns an
|
||||
//! internal server error response. This middleware helps ensure that the
|
||||
//! application can gracefully handle unexpected errors without crashing the
|
||||
//! server.
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{Router, response::IntoResponse};
|
||||
use http::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::catch_panic::CatchPanicLayer;
|
||||
|
||||
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CatchPanic {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
/// Handler function for the [`CatchPanicLayer`] middleware.
|
||||
///
|
||||
/// This function processes panics by extracting error messages, logging them,
|
||||
/// and returning an internal server error response.
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
let err = err.downcast_ref::<String>().map_or_else(
|
||||
|| err.downcast_ref::<&str>().map_or("no error details", |s| s),
|
||||
|s| s.as_str(),
|
||||
);
|
||||
|
||||
tracing::error!(err.msg = err, "server_panic");
|
||||
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for CatchPanic {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"catch_panic"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the Catch Panic middleware layer to the Axum router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(CatchPanicLayer::custom(handle_panic)))
|
||||
}
|
||||
}
|
||||
41
apps/recorder/src/web/middleware/compression.rs
Normal file
41
apps/recorder/src/web/middleware/compression.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
//! Compression Middleware for Axum
|
||||
//!
|
||||
//! This middleware applies compression to HTTP responses to reduce the size of
|
||||
//! the data being transmitted. This can improve performance by decreasing load
|
||||
//! times and reducing bandwidth usage. The middleware configuration allows for
|
||||
//! enabling or disabling compression based on the application settings.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::compression::CompressionLayer;
|
||||
|
||||
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Compression {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for Compression {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"compression"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the Compression middleware layer to the Axum router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(CompressionLayer::new()))
|
||||
}
|
||||
}
|
||||
163
apps/recorder/src/web/middleware/cors.rs
Normal file
163
apps/recorder/src/web/middleware/cors.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
//! Configurable and Flexible CORS Middleware
|
||||
//!
|
||||
//! This middleware enables Cross-Origin Resource Sharing (CORS) by allowing
|
||||
//! configurable origins, methods, and headers in HTTP requests. It can be
|
||||
//! tailored to fit various application requirements, supporting permissive CORS
|
||||
//! or specific rules as defined in the middleware configuration.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tower_http::cors::{self, Any};
|
||||
|
||||
use crate::{app::AppContext, web::middleware::MiddlewareLayer, errors::RResult};
|
||||
|
||||
/// CORS middleware configuration
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Cors {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
/// Allow origins
|
||||
#[serde(default = "default_allow_origins")]
|
||||
pub allow_origins: Vec<String>,
|
||||
/// Allow headers
|
||||
#[serde(default = "default_allow_headers")]
|
||||
pub allow_headers: Vec<String>,
|
||||
/// Allow methods
|
||||
#[serde(default = "default_allow_methods")]
|
||||
pub allow_methods: Vec<String>,
|
||||
/// Allow credentials
|
||||
#[serde(default)]
|
||||
pub allow_credentials: bool,
|
||||
/// Max age
|
||||
pub max_age: Option<u64>,
|
||||
// Vary headers
|
||||
#[serde(default = "default_vary_headers")]
|
||||
pub vary: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_allow_origins() -> Vec<String> {
|
||||
vec!["*".to_string()]
|
||||
}
|
||||
|
||||
fn default_allow_headers() -> Vec<String> {
|
||||
vec!["*".to_string()]
|
||||
}
|
||||
|
||||
fn default_allow_methods() -> Vec<String> {
|
||||
vec!["*".to_string()]
|
||||
}
|
||||
|
||||
fn default_vary_headers() -> Vec<String> {
|
||||
vec![
|
||||
"origin".to_string(),
|
||||
"access-control-request-method".to_string(),
|
||||
"access-control-request-headers".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
impl Default for Cors {
|
||||
fn default() -> Self {
|
||||
serde_json::from_value(json!({})).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Cors {
|
||||
/// Creates cors layer
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function returns an error in the following cases:
|
||||
///
|
||||
/// - If any of the provided origins in `allow_origins` cannot be parsed as
|
||||
/// a valid URI, the function will return a parsing error.
|
||||
/// - If any of the provided headers in `allow_headers` cannot be parsed as
|
||||
/// valid HTTP headers, the function will return a parsing error.
|
||||
/// - If any of the provided methods in `allow_methods` cannot be parsed as
|
||||
/// valid HTTP methods, the function will return a parsing error.
|
||||
///
|
||||
/// In all of these cases, the error returned will be the result of the
|
||||
/// `parse` method of the corresponding type.
|
||||
pub fn cors(&self) -> RResult<cors::CorsLayer> {
|
||||
let mut cors: cors::CorsLayer = cors::CorsLayer::new();
|
||||
|
||||
// testing CORS, assuming https://example.com in the allow list:
|
||||
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Acces
|
||||
// look for '< access-control-allow-origin: https://example.com' in response.
|
||||
// if it doesn't appear (test with a bogus domain), it is not allowed.
|
||||
if self.allow_origins == default_allow_origins() {
|
||||
cors = cors.allow_origin(Any);
|
||||
} else {
|
||||
let mut list = vec![];
|
||||
for origin in &self.allow_origins {
|
||||
list.push(origin.parse()?);
|
||||
}
|
||||
if !list.is_empty() {
|
||||
cors = cors.allow_origin(list);
|
||||
}
|
||||
}
|
||||
|
||||
if self.allow_headers == default_allow_headers() {
|
||||
cors = cors.allow_headers(Any);
|
||||
} else {
|
||||
let mut list = vec![];
|
||||
for header in &self.allow_headers {
|
||||
list.push(header.parse()?);
|
||||
}
|
||||
if !list.is_empty() {
|
||||
cors = cors.allow_headers(list);
|
||||
}
|
||||
}
|
||||
|
||||
if self.allow_methods == default_allow_methods() {
|
||||
cors = cors.allow_methods(Any);
|
||||
} else {
|
||||
let mut list = vec![];
|
||||
for method in &self.allow_methods {
|
||||
list.push(method.parse()?);
|
||||
}
|
||||
if !list.is_empty() {
|
||||
cors = cors.allow_methods(list);
|
||||
}
|
||||
}
|
||||
|
||||
let mut list = vec![];
|
||||
for v in &self.vary {
|
||||
list.push(v.parse()?);
|
||||
}
|
||||
if !list.is_empty() {
|
||||
cors = cors.vary(list);
|
||||
}
|
||||
|
||||
if let Some(max_age) = self.max_age {
|
||||
cors = cors.max_age(Duration::from_secs(max_age));
|
||||
}
|
||||
|
||||
cors = cors.allow_credentials(self.allow_credentials);
|
||||
|
||||
Ok(cors)
|
||||
}
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for Cors {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"cors"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the CORS middleware layer to the Axum router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(self.cors()?))
|
||||
}
|
||||
}
|
||||
111
apps/recorder/src/web/middleware/etag.rs
Normal file
111
apps/recorder/src/web/middleware/etag.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
//! `ETag` Middleware for Caching Requests
|
||||
//!
|
||||
//! This middleware implements the [ETag](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag)
|
||||
//! HTTP header for caching responses in Axum. `ETags` are used to validate
|
||||
//! cache entries by comparing a client's stored `ETag` with the one generated
|
||||
//! by the server. If the `ETags` match, a `304 Not Modified` response is sent,
|
||||
//! avoiding the need to resend the full content.
|
||||
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{
|
||||
StatusCode,
|
||||
header::{ETAG, IF_NONE_MATCH},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Etag {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for Etag {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"etag"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the `ETag` middleware to the application router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(EtagLayer))
|
||||
}
|
||||
}
|
||||
|
||||
/// [`EtagLayer`] struct for adding `ETag` functionality as a Tower service
|
||||
/// layer.
|
||||
#[derive(Default, Clone)]
|
||||
struct EtagLayer;
|
||||
|
||||
impl<S> Layer<S> for EtagLayer {
|
||||
type Service = EtagMiddleware<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
EtagMiddleware { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EtagMiddleware<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> Service<Request<Body>> for EtagMiddleware<S>
|
||||
where
|
||||
S: Service<Request, Response = Response> + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
// `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request) -> Self::Future {
|
||||
let ifnm = request.headers().get(IF_NONE_MATCH).cloned();
|
||||
|
||||
let future = self.inner.call(request);
|
||||
|
||||
let res_fut = async move {
|
||||
let response = future.await?;
|
||||
let etag_from_response = response.headers().get(ETAG).cloned();
|
||||
if let Some(etag_in_request) = ifnm {
|
||||
if let Some(etag_from_response) = etag_from_response {
|
||||
if etag_in_request == etag_from_response {
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::NOT_MODIFIED)
|
||||
.body(Body::empty())
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
};
|
||||
Box::pin(res_fut)
|
||||
}
|
||||
}
|
||||
71
apps/recorder/src/web/middleware/format.rs
Normal file
71
apps/recorder/src/web/middleware/format.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
//! Detect a content type and format and responds accordingly
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{
|
||||
header::{ACCEPT, CONTENT_TYPE, HeaderMap},
|
||||
request::Parts,
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::errors::RError as Error;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Format(pub RespondTo);
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub enum RespondTo {
|
||||
None,
|
||||
Html,
|
||||
Json,
|
||||
Xml,
|
||||
Other(String),
|
||||
}
|
||||
|
||||
fn detect_format(content_type: &str) -> RespondTo {
|
||||
if content_type.starts_with("application/json") {
|
||||
RespondTo::Json
|
||||
} else if content_type.starts_with("text/html") {
|
||||
RespondTo::Html
|
||||
} else if content_type.starts_with("text/xml")
|
||||
|| content_type.starts_with("application/xml")
|
||||
|| content_type.starts_with("application/xhtml")
|
||||
{
|
||||
RespondTo::Xml
|
||||
} else {
|
||||
RespondTo::Other(content_type.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_respond_to(headers: &HeaderMap) -> RespondTo {
|
||||
#[allow(clippy::option_if_let_else)]
|
||||
if let Some(content_type) = headers.get(CONTENT_TYPE).and_then(|h| h.to_str().ok()) {
|
||||
detect_format(content_type)
|
||||
} else if let Some(content_type) = headers.get(ACCEPT).and_then(|h| h.to_str().ok()) {
|
||||
detect_format(content_type)
|
||||
} else {
|
||||
RespondTo::None
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for Format
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Error;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
|
||||
Ok(Self(get_respond_to(&parts.headers)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for RespondTo
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Error;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
|
||||
Ok(get_respond_to(&parts.headers))
|
||||
}
|
||||
}
|
||||
102
apps/recorder/src/web/middleware/logger.rs
Normal file
102
apps/recorder/src/web/middleware/logger.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! Logger Middleware
|
||||
//!
|
||||
//! This middleware provides logging functionality for HTTP requests. It uses
|
||||
//! `TraceLayer` to log detailed information about each request, such as the
|
||||
//! HTTP method, URI, version, user agent, and an associated request ID.
|
||||
//! Additionally, it integrates the application's runtime environment
|
||||
//! into the log context, allowing environment-specific logging (e.g.,
|
||||
//! "development", "production").
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{Router, http};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
|
||||
|
||||
use crate::{
|
||||
app::{AppContext, Environment},
|
||||
errors::RResult,
|
||||
web::middleware::{MiddlewareLayer, request_id::LocoRequestId},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Config {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
/// [`Middleware`] struct responsible for logging HTTP requests.
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Middleware {
|
||||
config: Config,
|
||||
environment: Environment,
|
||||
}
|
||||
|
||||
/// Creates a new instance of [`Middleware`] by cloning the [`Config`]
|
||||
/// configuration.
|
||||
#[must_use]
|
||||
pub fn new(config: &Config, context: Arc<AppContext>) -> Middleware {
|
||||
Middleware {
|
||||
config: config.clone(),
|
||||
environment: context.environment.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for Middleware {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"logger"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.config.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the logger middleware to the application router by adding layers
|
||||
/// for:
|
||||
///
|
||||
/// - `TraceLayer`: Logs detailed information about each HTTP request.
|
||||
/// - `AddExtensionLayer`: Adds the current environment to the request
|
||||
/// extensions, making it accessible to the `TraceLayer` for logging.
|
||||
///
|
||||
/// The `TraceLayer` is customized with `make_span_with` to extract
|
||||
/// request-specific details like method, URI, version, user agent, and
|
||||
/// request ID, then create a tracing span for the request.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app
|
||||
.layer(
|
||||
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
|
||||
let ext = request.extensions();
|
||||
let request_id = ext
|
||||
.get::<LocoRequestId>()
|
||||
.map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string());
|
||||
let user_agent = request
|
||||
.headers()
|
||||
.get(axum::http::header::USER_AGENT)
|
||||
.map_or("", |h| h.to_str().unwrap_or(""));
|
||||
|
||||
let env: String = request
|
||||
.extensions()
|
||||
.get::<Environment>()
|
||||
.map(|e| e.full_name().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
tracing::error_span!(
|
||||
"http-request",
|
||||
"http.method" = tracing::field::display(request.method()),
|
||||
"http.uri" = tracing::field::display(request.uri()),
|
||||
"http.version" = tracing::field::debug(request.version()),
|
||||
"http.user_agent" = tracing::field::display(user_agent),
|
||||
"environment" = tracing::field::display(env),
|
||||
request_id = tracing::field::display(request_id),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.layer(AddExtensionLayer::new(self.environment.clone())))
|
||||
}
|
||||
}
|
||||
165
apps/recorder/src/web/middleware/mod.rs
Normal file
165
apps/recorder/src/web/middleware/mod.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
pub mod catch_panic;
|
||||
pub mod compression;
|
||||
pub mod cors;
|
||||
pub mod etag;
|
||||
pub mod format;
|
||||
pub mod logger;
|
||||
pub mod remote_ip;
|
||||
pub mod request_id;
|
||||
pub mod secure_headers;
|
||||
pub mod timeout;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{app::AppContext, errors::RResult};
|
||||
|
||||
/// Trait representing the behavior of middleware components in the application.
|
||||
/// When implementing a new middleware, make sure to go over this checklist:
|
||||
/// * The name of the middleware should be an ID that is similar to the field
|
||||
/// name in configuration (look at how `serde` calls it)
|
||||
/// * Default value implementation should be paired with `serde` default
|
||||
/// handlers and default serialization implementation. Which means deriving
|
||||
/// `Default` will _not_ work. You can use `serde_json` and serialize a new
|
||||
/// config from an empty value, which will cause `serde` default value
|
||||
/// handlers to kick in.
|
||||
/// * If you need completely blank values for configuration (for example for
|
||||
/// testing), implement an `::empty() -> Self` call ad-hoc.
|
||||
pub trait MiddlewareLayer {
|
||||
/// Returns the name of the middleware.
|
||||
/// This should match the name of the property in the containing
|
||||
/// `middleware` section in configuration (as named by `serde`)
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Returns whether the middleware is enabled or not.
|
||||
/// If the middleware is switchable, take this value from a configuration
|
||||
/// value
|
||||
fn is_enabled(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Returns middleware config.
|
||||
///
|
||||
/// # Errors
|
||||
/// when could not convert middleware to [`serde_json::Value`]
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value>;
|
||||
|
||||
/// Applies the middleware to the given Axum router and returns the modified
|
||||
/// router.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If there is an issue when adding the middleware to the router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>>;
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_lazy_evaluations)]
|
||||
#[must_use]
|
||||
pub fn default_middleware_stack(ctx: Arc<AppContext>) -> Vec<Box<dyn MiddlewareLayer>> {
|
||||
// Shortened reference to middlewares
|
||||
let middlewares = &ctx.config.server.middlewares;
|
||||
|
||||
vec![
|
||||
// CORS middleware with a default if none
|
||||
Box::new(middlewares.cors.clone().unwrap_or_else(|| cors::Cors {
|
||||
enable: false,
|
||||
..Default::default()
|
||||
})),
|
||||
// Catch Panic middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.catch_panic
|
||||
.clone()
|
||||
.unwrap_or_else(|| catch_panic::CatchPanic { enable: true }),
|
||||
),
|
||||
// Etag middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.etag
|
||||
.clone()
|
||||
.unwrap_or_else(|| etag::Etag { enable: true }),
|
||||
),
|
||||
// Remote IP middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.remote_ip
|
||||
.clone()
|
||||
.unwrap_or_else(|| remote_ip::RemoteIpMiddleware {
|
||||
enable: false,
|
||||
..Default::default()
|
||||
}),
|
||||
),
|
||||
// Compression middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.compression
|
||||
.clone()
|
||||
.unwrap_or_else(|| compression::Compression { enable: false }),
|
||||
),
|
||||
// Timeout Request middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.timeout_request
|
||||
.clone()
|
||||
.unwrap_or_else(|| timeout::TimeOut {
|
||||
enable: false,
|
||||
..Default::default()
|
||||
}),
|
||||
),
|
||||
// Secure Headers middleware with a default if none
|
||||
Box::new(middlewares.secure_headers.clone().unwrap_or_else(|| {
|
||||
secure_headers::SecureHeader {
|
||||
enable: false,
|
||||
..Default::default()
|
||||
}
|
||||
})),
|
||||
// Logger middleware with default logger configuration
|
||||
Box::new(logger::new(
|
||||
&middlewares
|
||||
.logger
|
||||
.clone()
|
||||
.unwrap_or_else(|| logger::Config { enable: true }),
|
||||
ctx.clone(),
|
||||
)),
|
||||
// Request ID middleware with a default if none
|
||||
Box::new(
|
||||
middlewares
|
||||
.request_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| request_id::RequestId { enable: true }),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// Server middleware configuration structure.
|
||||
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MiddlewareConfig {
|
||||
/// Compression for the response.
|
||||
pub compression: Option<compression::Compression>,
|
||||
|
||||
/// Etag cache headers.
|
||||
pub etag: Option<etag::Etag>,
|
||||
|
||||
/// Logger and augmenting trace id with request data
|
||||
pub logger: Option<logger::Config>,
|
||||
|
||||
/// Catch any code panic and log the error.
|
||||
pub catch_panic: Option<catch_panic::CatchPanic>,
|
||||
|
||||
/// Setting a global timeout for requests
|
||||
pub timeout_request: Option<timeout::TimeOut>,
|
||||
|
||||
/// CORS configuration
|
||||
pub cors: Option<cors::Cors>,
|
||||
|
||||
/// Sets a set of secure headers
|
||||
pub secure_headers: Option<secure_headers::SecureHeader>,
|
||||
|
||||
/// Calculates a remote IP based on `X-Forwarded-For` when behind a proxy
|
||||
pub remote_ip: Option<remote_ip::RemoteIpMiddleware>,
|
||||
|
||||
/// Request ID
|
||||
pub request_id: Option<request_id::RequestId>,
|
||||
}
|
||||
306
apps/recorder/src/web/middleware/remote_ip.rs
Normal file
306
apps/recorder/src/web/middleware/remote_ip.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
//! Remote IP Middleware for inferring the client's IP address based on the
|
||||
//! `X-Forwarded-For` header.
|
||||
//!
|
||||
//! This middleware is useful when running behind proxies or load balancers that
|
||||
//! add the `X-Forwarded-For` header, which includes the original client IP
|
||||
//! address.
|
||||
//!
|
||||
//! The middleware provides a mechanism to configure trusted proxies and extract
|
||||
//! the most likely client IP from the `X-Forwarded-For` header, skipping any
|
||||
//! trusted proxy IPs.
|
||||
use std::{
|
||||
fmt,
|
||||
iter::Iterator,
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
sync::{Arc, OnceLock},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
body::Body,
|
||||
extract::{ConnectInfo, FromRequestParts, Request},
|
||||
http::{header::HeaderMap, request::Parts},
|
||||
response::Response,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use ipnetwork::IpNetwork;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::{Layer, Service};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
app::AppContext,
|
||||
errors::{RError, RResult},
|
||||
web::middleware::MiddlewareLayer,
|
||||
};
|
||||
|
||||
static LOCAL_TRUSTED_PROXIES: OnceLock<Vec<IpNetwork>> = OnceLock::new();
|
||||
|
||||
fn get_local_trusted_proxies() -> &'static Vec<IpNetwork> {
|
||||
LOCAL_TRUSTED_PROXIES.get_or_init(|| {
|
||||
[
|
||||
"127.0.0.0/8", // localhost IPv4 range, per RFC-3330
|
||||
"::1", // localhost IPv6
|
||||
"fc00::/7", // private IPv6 range fc00::/7
|
||||
"10.0.0.0/8", // private IPv4 range 10.x.x.x
|
||||
"172.16.0.0/12", // private IPv4 range 172.16.0.0 .. 172.31.255.255
|
||||
"192.168.0.0/16",
|
||||
]
|
||||
.iter()
|
||||
.map(|ip| IpNetwork::from_str(ip).unwrap())
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
const X_FORWARDED_FOR: &str = "X-Forwarded-For";
|
||||
|
||||
///
|
||||
/// Performs a remote ip "calculation", inferring the most likely
|
||||
/// client IP from the `X-Forwarded-For` header that is used by
|
||||
/// load balancers and proxies.
|
||||
///
|
||||
/// WARNING
|
||||
/// =======
|
||||
///
|
||||
/// LIKE ANY SUCH REMOTE IP MIDDLEWARE, IN THE WRONG ARCHITECTURE IT CAN MAKE
|
||||
/// YOU VULNERABLE TO IP SPOOFING.
|
||||
///
|
||||
/// This middleware assumes that there is at least one proxy sitting around and
|
||||
/// setting headers with the client's remote IP address. Otherwise any client
|
||||
/// can claim to have any IP address by setting the `X-Forwarded-For` header.
|
||||
///
|
||||
/// DO NOT USE THIS MIDDLEWARE IF YOU DONT KNOW THAT YOU NEED IT
|
||||
///
|
||||
/// -- But if you need it, it's crucial to use it (since it's the only way to
|
||||
/// get the original client IP)
|
||||
///
|
||||
/// This middleware is mostly implemented after the Rails `remote_ip`
|
||||
/// middleware, and looking at other production Rust services with Axum, taking
|
||||
/// the best of both worlds to balance performance and pragmatism.
|
||||
///
|
||||
/// Similarities to the Rails `remote_ip` middleware:
|
||||
///
|
||||
/// * Uses `X-Forwarded-For`
|
||||
/// * Uses the same built-in trusted proxies list
|
||||
/// * You can provide a list of `trusted_proxies` which will **replace** the
|
||||
/// built-in trusted proxies
|
||||
///
|
||||
/// Differences from the Rails `remote_ip` middleware:
|
||||
///
|
||||
/// * You get an indication if the remote IP is actually resolved or is the
|
||||
/// socket IP (no `X-Forwarded-For` header or could not parse)
|
||||
/// * We do not not use the `Client-IP` header, or try to detect "spoofing"
|
||||
/// (spoofing while doing remote IP resolution is virtually non-detectable)
|
||||
/// * Order of filtering IPs from `X-Forwarded-For` is done according to [the de
|
||||
/// facto spec](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#selecting_an_ip_address)
|
||||
/// "Trusted proxy list"
|
||||
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RemoteIpMiddleware {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
/// A list of alternative proxy list IP ranges and/or network range (will
|
||||
/// replace built-in proxy list)
|
||||
pub trusted_proxies: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for RemoteIpMiddleware {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"remote_ip"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
&& (self.trusted_proxies.is_none()
|
||||
|| self.trusted_proxies.as_ref().is_some_and(|t| !t.is_empty()))
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the Remote IP middleware to the given Axum router.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(RemoteIPLayer::new(self)?))
|
||||
}
|
||||
}
|
||||
|
||||
// implementation reference: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
|
||||
fn maybe_get_forwarded(
|
||||
headers: &HeaderMap,
|
||||
trusted_proxies: Option<&Vec<IpNetwork>>,
|
||||
) -> Option<IpAddr> {
|
||||
/*
|
||||
> There may be multiple X-Forwarded-For headers present in a request. The IP addresses in these headers must be treated as a single list,
|
||||
> starting with the first IP address of the first header and continuing to the last IP address of the last header.
|
||||
> There are two ways of making this single list:
|
||||
> join the X-Forwarded-For full header values with commas and then split by comma into a list, or
|
||||
> split each X-Forwarded-For header by comma into lists and then join the lists
|
||||
*/
|
||||
let xffs = headers
|
||||
.get_all(X_FORWARDED_FOR)
|
||||
.iter()
|
||||
.map(|hdr| hdr.to_str())
|
||||
.filter_map(Result::ok)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if xffs.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let forwarded = xffs.join(",");
|
||||
|
||||
forwarded
|
||||
.split(',')
|
||||
.map(str::trim)
|
||||
.map(str::parse)
|
||||
.filter_map(Result::ok)
|
||||
/*
|
||||
> Trusted proxy list: The IPs or IP ranges of the trusted reverse proxies are configured.
|
||||
> The X-Forwarded-For IP list is searched from the rightmost, skipping all addresses that
|
||||
> are on the trusted proxy list. The first non-matching address is the target address.
|
||||
*/
|
||||
.filter(|ip| {
|
||||
// trusted proxies provided REPLACES our default local proxies
|
||||
let proxies = trusted_proxies.unwrap_or_else(|| get_local_trusted_proxies());
|
||||
!proxies
|
||||
.iter()
|
||||
.any(|trusted_proxy| trusted_proxy.contains(*ip))
|
||||
})
|
||||
/*
|
||||
> When choosing the X-Forwarded-For client IP address closest to the client (untrustworthy
|
||||
> and not for security-related purposes), the first IP from the leftmost that is a valid
|
||||
> address and not private/internal should be selected.
|
||||
>
|
||||
NOTE:
|
||||
> The first trustworthy X-Forwarded-For IP address may belong to an untrusted intermediate
|
||||
> proxy rather than the actual client computer, but it is the only IP suitable for security uses.
|
||||
*/
|
||||
.next_back()
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum RemoteIP {
|
||||
Forwarded(IpAddr),
|
||||
Socket(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for RemoteIP
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = ();
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let ip = parts.extensions.get::<Self>();
|
||||
Ok(*ip.unwrap_or(&Self::None))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RemoteIP {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Forwarded(ip) => write!(f, "remote: {ip}"),
|
||||
Self::Socket(ip) => write!(f, "socket: {ip}"),
|
||||
Self::None => write!(f, "--"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct RemoteIPLayer {
|
||||
trusted_proxies: Option<Vec<IpNetwork>>,
|
||||
}
|
||||
|
||||
impl RemoteIPLayer {
|
||||
/// Returns new secure headers middleware
|
||||
///
|
||||
/// # Errors
|
||||
/// Fails if invalid header values found
|
||||
pub fn new(config: &RemoteIpMiddleware) -> RResult<Self> {
|
||||
Ok(Self {
|
||||
trusted_proxies: config
|
||||
.trusted_proxies
|
||||
.as_ref()
|
||||
.map(|proxies| {
|
||||
proxies
|
||||
.iter()
|
||||
.map(|proxy| {
|
||||
IpNetwork::from_str(proxy).map_err(|err| {
|
||||
RError::CustomMessageString(format!(
|
||||
"remote ip middleare cannot parse trusted proxy \
|
||||
configuration: `{proxy}`, reason: `{err}`",
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<RResult<Vec<_>>>()
|
||||
})
|
||||
.transpose()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for RemoteIPLayer {
|
||||
type Service = RemoteIPMiddleware<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
RemoteIPMiddleware {
|
||||
inner,
|
||||
layer: self.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remote IP Detection Middleware
|
||||
#[derive(Clone, Debug)]
|
||||
#[must_use]
|
||||
pub struct RemoteIPMiddleware<S> {
|
||||
inner: S,
|
||||
layer: RemoteIPLayer,
|
||||
}
|
||||
|
||||
impl<S> Service<Request<Body>> for RemoteIPMiddleware<S>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response> + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
|
||||
let layer = self.layer.clone();
|
||||
let xff_ip = maybe_get_forwarded(req.headers(), layer.trusted_proxies.as_ref());
|
||||
let remote_ip = xff_ip.map_or_else(
|
||||
|| {
|
||||
let ip = req
|
||||
.extensions()
|
||||
.get::<ConnectInfo<SocketAddr>>()
|
||||
.map_or_else(
|
||||
|| {
|
||||
error!(
|
||||
"remote ip middleware cannot get socket IP (not set in axum \
|
||||
extensions): setting IP to `127.0.0.1`"
|
||||
);
|
||||
RemoteIP::None
|
||||
},
|
||||
|info| RemoteIP::Socket(info.ip()),
|
||||
);
|
||||
ip
|
||||
},
|
||||
RemoteIP::Forwarded,
|
||||
);
|
||||
|
||||
req.extensions_mut().insert(remote_ip);
|
||||
|
||||
Box::pin(self.inner.call(req))
|
||||
}
|
||||
}
|
||||
132
apps/recorder/src/web/middleware/request_id.rs
Normal file
132
apps/recorder/src/web/middleware/request_id.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
//! Middleware to generate or ensure a unique request ID for every request.
|
||||
//!
|
||||
//! The request ID is stored in the `x-request-id` header, and it is either
|
||||
//! generated or sanitized if already present in the request.
|
||||
//!
|
||||
//! This can be useful for tracking requests across services, logging, and
|
||||
//! debugging.
|
||||
|
||||
use axum::{Router, extract::Request, http::HeaderValue, middleware::Next, response::Response};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{web::middleware::MiddlewareLayer, app::AppContext, errors::RResult};
|
||||
|
||||
const X_REQUEST_ID: &str = "x-request-id";
|
||||
const MAX_LEN: usize = 255;
|
||||
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
static ID_CLEANUP: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
fn get_id_cleanup() -> &'static Regex {
|
||||
ID_CLEANUP.get_or_init(|| Regex::new(r"[^\w\-@]").unwrap())
|
||||
}
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RequestId {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for RequestId {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"request_id"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the request ID middleware to the Axum router.
|
||||
///
|
||||
/// This function sets up the middleware in the router and ensures that
|
||||
/// every request passing through it will have a unique or sanitized
|
||||
/// request ID.
|
||||
///
|
||||
/// # Errors
|
||||
/// This function returns an error if the middleware cannot be applied.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(axum::middleware::from_fn(request_id_middleware)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct for storing the request ID in the request's extensions.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocoRequestId(String);
|
||||
|
||||
impl LocoRequestId {
|
||||
/// Retrieves the request ID as a string slice.
|
||||
#[must_use]
|
||||
pub fn get(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware function to ensure or generate a unique request ID.
|
||||
///
|
||||
/// This function intercepts requests, checks for the presence of the
|
||||
/// `x-request-id` header, and either sanitizes its value or generates a new
|
||||
/// UUID if absent. The resulting request ID is added to both the request
|
||||
/// extensions and the response headers.
|
||||
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
|
||||
let header_request_id = request.headers().get(X_REQUEST_ID).cloned();
|
||||
let request_id = make_request_id(header_request_id);
|
||||
request
|
||||
.extensions_mut()
|
||||
.insert(LocoRequestId(request_id.clone()));
|
||||
let mut res = next.run(request).await;
|
||||
|
||||
if let Ok(v) = HeaderValue::from_str(request_id.as_str()) {
|
||||
res.headers_mut().insert(X_REQUEST_ID, v);
|
||||
} else {
|
||||
tracing::warn!("could not set request ID into response headers: `{request_id}`",);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
/// Generates or sanitizes a request ID.
|
||||
fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
|
||||
maybe_request_id
|
||||
.and_then(|hdr| {
|
||||
// see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39
|
||||
let id: Option<String> = hdr.to_str().ok().map(|s| {
|
||||
get_id_cleanup()
|
||||
.replace_all(s, "")
|
||||
.chars()
|
||||
.take(MAX_LEN)
|
||||
.collect()
|
||||
});
|
||||
id.filter(|s| !s.is_empty())
|
||||
})
|
||||
.unwrap_or_else(|| Uuid::new_v4().to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::http::HeaderValue;
|
||||
use insta::assert_debug_snapshot;
|
||||
|
||||
use super::make_request_id;
|
||||
|
||||
#[test]
|
||||
fn create_or_fetch_request_id() {
|
||||
let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
|
||||
assert_debug_snapshot!(id);
|
||||
let id = make_request_id(Some(HeaderValue::from_static("")));
|
||||
assert_debug_snapshot!(id.len());
|
||||
let id = make_request_id(Some(HeaderValue::from_static("==========")));
|
||||
assert_debug_snapshot!(id.len());
|
||||
let long_id = "x".repeat(1000);
|
||||
let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
|
||||
assert_debug_snapshot!(id.len());
|
||||
let id = make_request_id(None);
|
||||
assert_debug_snapshot!(id.len());
|
||||
}
|
||||
}
|
||||
26
apps/recorder/src/web/middleware/secure_headers.json
Normal file
26
apps/recorder/src/web/middleware/secure_headers.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"empty":{},
|
||||
"github":{
|
||||
"Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src https:; style-src 'self' https: 'unsafe-inline'",
|
||||
"Strict-Transport-Security": "max-age=631138519",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Download-Options": "noopen",
|
||||
"X-Frame-Options": "sameorigin",
|
||||
"X-Permitted-Cross-Domain-Policies": "none",
|
||||
"X-Xss-Protection": "0"
|
||||
},
|
||||
"owasp":{
|
||||
"Cache-Control": "no-store, max-age=0",
|
||||
"Clear-Site-Data": "\"cache\",\"cookies\",\"storage\"",
|
||||
"Content-Security-Policy": "default-src 'self'; form-action 'self'; object-src 'none'; frame-ancestors 'none'; upgrade-insecure-requests; block-all-mixed-content",
|
||||
"Cross-Origin-Embedder-Policy": "require-corp",
|
||||
"Cross-Origin-Opener-Policy": "same-origin",
|
||||
"Cross-Origin-Resource-Policy": "same-origin",
|
||||
"Permissions-Policy": "accelerometer=(), autoplay=(), camera=(), cross-origin-isolated=(), display-capture=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(self), usb=(), web-share=(), xr-spatial-tracking=(), clipboard-read=(), clipboard-write=(), gamepad=(), hid=(), idle-detection=(), interest-cohort=(), serial=(), unload=()",
|
||||
"Referrer-Policy": "no-referrer",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "deny",
|
||||
"X-Permitted-Cross-Domain-Policies": "none"
|
||||
}
|
||||
}
|
||||
311
apps/recorder/src/web/middleware/secure_headers.rs
Normal file
311
apps/recorder/src/web/middleware/secure_headers.rs
Normal file
@@ -0,0 +1,311 @@
|
||||
//! Sets secure headers for your backend to promote security-by-default.
|
||||
//!
|
||||
//! This middleware applies secure HTTP headers, providing pre-defined presets
|
||||
//! (e.g., "github") and the ability to override or define custom headers.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::{Arc, OnceLock},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
body::Body,
|
||||
http::{HeaderName, HeaderValue, Request},
|
||||
response::Response,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{self, json};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
app::AppContext,
|
||||
web::middleware::MiddlewareLayer,
|
||||
errors::{RError, RResult},
|
||||
};
|
||||
|
||||
static PRESETS: OnceLock<HashMap<String, BTreeMap<String, String>>> = OnceLock::new();
|
||||
fn get_presets() -> &'static HashMap<String, BTreeMap<String, String>> {
|
||||
PRESETS.get_or_init(|| {
|
||||
let json_data = include_str!("secure_headers.json");
|
||||
serde_json::from_str(json_data).unwrap()
|
||||
})
|
||||
}
|
||||
/// Sets a predefined or custom set of secure headers.
|
||||
///
|
||||
/// We recommend our `github` preset. Presets values are derived
|
||||
/// from the [secure_headers](https://github.com/github/secure_headers) Ruby
|
||||
/// library which Github (and originally Twitter) use.
|
||||
///
|
||||
/// To use a preset, in your `config/development.yaml`:
|
||||
///
|
||||
/// ```yaml
|
||||
/// middlewares:
|
||||
/// secure_headers:
|
||||
/// preset: github
|
||||
/// ```
|
||||
///
|
||||
/// You can also override individual headers on a given preset:
|
||||
///
|
||||
/// ```yaml
|
||||
/// middlewares:
|
||||
/// secure_headers:
|
||||
/// preset: github
|
||||
/// overrides:
|
||||
/// foo: bar
|
||||
/// ```
|
||||
///
|
||||
/// Or start from scratch:
|
||||
///
|
||||
///```yaml
|
||||
/// middlewares:
|
||||
/// secure_headers:
|
||||
/// preset: empty
|
||||
/// overrides:
|
||||
/// one: two
|
||||
/// ```
|
||||
///
|
||||
/// To support `htmx`, You can add the following override, to allow some inline
|
||||
/// running of scripts:
|
||||
///
|
||||
/// ```yaml
|
||||
/// secure_headers:
|
||||
/// preset: github
|
||||
/// overrides:
|
||||
/// # this allows you to use HTMX, and has unsafe-inline. Remove or consider in production
|
||||
/// "Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src 'unsafe-inline' 'self' https:; style-src 'self' https: 'unsafe-inline'"
|
||||
/// ```
|
||||
///
|
||||
/// For the list of presets and their content look at [secure_headers.json](https://github.com/loco-rs/loco/blob/master/src/controller/middleware/secure_headers.rs)
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct SecureHeader {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
#[serde(default = "default_preset")]
|
||||
pub preset: String,
|
||||
#[serde(default)]
|
||||
pub overrides: Option<BTreeMap<String, String>>,
|
||||
}
|
||||
|
||||
impl Default for SecureHeader {
|
||||
fn default() -> Self {
|
||||
serde_json::from_value(json!({})).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_preset() -> String {
|
||||
"github".to_string()
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for SecureHeader {
|
||||
/// Returns the name of the middleware
|
||||
fn name(&self) -> &'static str {
|
||||
"secure_headers"
|
||||
}
|
||||
|
||||
/// Returns whether the middleware is enabled or not
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the secure headers layer to the application router
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(SecureHeaders::new(self)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl SecureHeader {
|
||||
/// Converts the configuration into a list of headers.
|
||||
///
|
||||
/// Applies the preset headers and any custom overrides.
|
||||
fn as_headers(&self) -> RResult<Vec<(HeaderName, HeaderValue)>> {
|
||||
let mut headers = vec![];
|
||||
|
||||
let preset = &self.preset;
|
||||
let p = get_presets().get(preset).ok_or_else(|| {
|
||||
RError::CustomMessageString(format!(
|
||||
"secure_headers: a preset named `{preset}` does not exist"
|
||||
))
|
||||
})?;
|
||||
|
||||
Self::push_headers(&mut headers, p)?;
|
||||
if let Some(overrides) = &self.overrides {
|
||||
Self::push_headers(&mut headers, overrides)?;
|
||||
}
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Helper function to push headers into a mutable vector.
|
||||
///
|
||||
/// This function takes a map of header names and values, converting them
|
||||
/// into valid HTTP headers and adding them to the provided `headers`
|
||||
/// vector.
|
||||
fn push_headers(
|
||||
headers: &mut Vec<(HeaderName, HeaderValue)>,
|
||||
hm: &BTreeMap<String, String>,
|
||||
) -> RResult<()> {
|
||||
for (k, v) in hm {
|
||||
headers.push((
|
||||
HeaderName::from_bytes(k.clone().as_bytes())?,
|
||||
HeaderValue::from_str(v.clone().as_str())?,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// The [`SecureHeaders`] layer which wraps around the service and injects
|
||||
/// security headers
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SecureHeaders {
|
||||
headers: Vec<(HeaderName, HeaderValue)>,
|
||||
}
|
||||
|
||||
impl SecureHeaders {
|
||||
/// Creates a new [`SecureHeaders`] instance with the provided
|
||||
/// configuration.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if any header values are invalid.
|
||||
pub fn new(config: &SecureHeader) -> RResult<Self> {
|
||||
Ok(Self {
|
||||
headers: config.as_headers()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for SecureHeaders {
|
||||
type Service = SecureHeadersMiddleware<S>;
|
||||
|
||||
/// Wraps the provided service with the secure headers middleware.
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
SecureHeadersMiddleware {
|
||||
inner,
|
||||
layer: self.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The secure headers middleware
|
||||
#[derive(Clone, Debug)]
|
||||
#[must_use]
|
||||
pub struct SecureHeadersMiddleware<S> {
|
||||
inner: S,
|
||||
layer: SecureHeaders,
|
||||
}
|
||||
|
||||
impl<S> Service<Request<Body>> for SecureHeadersMiddleware<S>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response> + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request<Body>) -> Self::Future {
|
||||
let layer = self.layer.clone();
|
||||
let future = self.inner.call(request);
|
||||
Box::pin(async move {
|
||||
let mut response: Response = future.await?;
|
||||
let headers = response.headers_mut();
|
||||
for (k, v) in &layer.headers {
|
||||
headers.insert(k, v.clone());
|
||||
}
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
http::{HeaderMap, Method},
|
||||
routing::get,
|
||||
};
|
||||
use insta::assert_debug_snapshot;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use super::*;
|
||||
fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
|
||||
headers
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
let key = k.to_string();
|
||||
let value = v.to_str().unwrap_or("").to_string();
|
||||
(key, value)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn can_set_headers() {
|
||||
let config = SecureHeader {
|
||||
enable: true,
|
||||
preset: "github".to_string(),
|
||||
overrides: None,
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.layer(SecureHeaders::new(&config).unwrap());
|
||||
|
||||
let req = Request::builder()
|
||||
.uri("/")
|
||||
.method(Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
let response = app.oneshot(req).await.unwrap();
|
||||
assert_debug_snapshot!(normalize_headers(response.headers()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn can_override_headers() {
|
||||
let mut overrides = BTreeMap::new();
|
||||
overrides.insert("X-Download-Options".to_string(), "foobar".to_string());
|
||||
overrides.insert("New-Header".to_string(), "baz".to_string());
|
||||
|
||||
let config = SecureHeader {
|
||||
enable: true,
|
||||
preset: "github".to_string(),
|
||||
overrides: Some(overrides),
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.layer(SecureHeaders::new(&config).unwrap());
|
||||
|
||||
let req = Request::builder()
|
||||
.uri("/")
|
||||
.method(Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
let response = app.oneshot(req).await.unwrap();
|
||||
assert_debug_snapshot!(normalize_headers(response.headers()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_is_github_preset() {
|
||||
let config = SecureHeader::default();
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.layer(SecureHeaders::new(&config).unwrap());
|
||||
|
||||
let req = Request::builder()
|
||||
.uri("/")
|
||||
.method(Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
let response = app.oneshot(req).await.unwrap();
|
||||
assert_debug_snapshot!(normalize_headers(response.headers()));
|
||||
}
|
||||
}
|
||||
64
apps/recorder/src/web/middleware/timeout.rs
Normal file
64
apps/recorder/src/web/middleware/timeout.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
//! Timeout Request Middleware.
|
||||
//!
|
||||
//! This middleware applies a timeout to requests processed by the application.
|
||||
//! The timeout duration is configurable and defined via the
|
||||
//! [`TimeoutRequestMiddleware`] configuration. The middleware ensures that
|
||||
//! requests do not run beyond the specified timeout period, improving the
|
||||
//! overall performance and responsiveness of the application.
|
||||
//!
|
||||
//! If a request exceeds the specified timeout duration, the middleware will
|
||||
//! return a `408 Request Timeout` status code to the client, indicating that
|
||||
//! the request took too long to process.
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
|
||||
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
|
||||
|
||||
/// Timeout middleware configuration
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TimeOut {
|
||||
#[serde(default)]
|
||||
pub enable: bool,
|
||||
// Timeout request in milliseconds
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
impl Default for TimeOut {
|
||||
fn default() -> Self {
|
||||
serde_json::from_value(json!({})).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
5_000
|
||||
}
|
||||
|
||||
impl MiddlewareLayer for TimeOut {
|
||||
/// Returns the name of the middleware.
|
||||
fn name(&self) -> &'static str {
|
||||
"timeout_request"
|
||||
}
|
||||
|
||||
/// Checks if the timeout middleware is enabled.
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enable
|
||||
}
|
||||
|
||||
fn config(&self) -> serde_json::Result<serde_json::Value> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Applies the timeout middleware to the application router.
|
||||
///
|
||||
/// This method wraps the provided [`AXRouter`] in a [`TimeoutLayer`],
|
||||
/// ensuring that requests exceeding the specified timeout duration will
|
||||
/// be interrupted.
|
||||
fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
|
||||
Ok(app.layer(TimeoutLayer::new(Duration::from_millis(self.timeout))))
|
||||
}
|
||||
}
|
||||
5
apps/recorder/src/web/mod.rs
Normal file
5
apps/recorder/src/web/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod config;
|
||||
pub mod controller;
|
||||
pub mod middleware;
|
||||
|
||||
pub use config::WebServerConfig;
|
||||
Reference in New Issue
Block a user