refactor: remove loco-rs deps

This commit is contained in:
master 2025-03-01 15:21:14 +08:00
parent a68aab1452
commit 2844e1fc32
66 changed files with 2565 additions and 1876 deletions

1426
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -22,8 +22,6 @@ testcontainers = [
] ]
[dependencies] [dependencies]
loco-rs = { version = "0.14" }
zino = { version = "0.33", features = ["axum"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
tokio = { version = "1.42", features = ["macros", "fs", "rt-multi-thread"] } tokio = { version = "1.42", features = ["macros", "fs", "rt-multi-thread"] }
@ -97,8 +95,17 @@ seaography = { version = "1.1" }
quirks_path = "0.1.1" quirks_path = "0.1.1"
base64 = "0.22.1" base64 = "0.22.1"
tower = "0.5.2" tower = "0.5.2"
axum-extra = "0.10.0" axum-extra = "0.10"
tower-http = "0.6.2" tower-http = { version = "0.6", features = [
"trace",
"catch-panic",
"timeout",
"add-extension",
"cors",
"fs",
"set-header",
"compression-full",
] }
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
tera = "1.20.0" tera = "1.20.0"
openidconnect = { version = "4", features = ["rustls-tls"] } openidconnect = { version = "4", features = ["rustls-tls"] }
@ -119,10 +126,14 @@ secrecy = { version = "0.10.3", features = ["serde"] }
http = "1.2.0" http = "1.2.0"
cookie = "0.18.1" cookie = "0.18.1"
async-stream = "0.3.6" async-stream = "0.3.6"
serde_variant = "0.1.3"
tracing-appender = "0.2.3"
clap = "4.5.31"
futures-util = "0.3.31"
ipnetwork = "0.21.1"
[dev-dependencies] [dev-dependencies]
serial_test = "3" serial_test = "3"
loco-rs = { version = "0.14", features = ["testing"] }
insta = { version = "1", features = ["redactions", "yaml", "filters"] } insta = { version = "1", features = ["redactions", "yaml", "filters"] }
mockito = "1.6.1" mockito = "1.6.1"
rstest = "0.24.0" rstest = "0.24.0"

View File

@ -1,125 +0,0 @@
# Loco configuration file documentation
# Application logging configuration
logger:
# Enable or disable logging.
enable: true
# Enable pretty backtrace (sets RUST_BACKTRACE=1)
pretty_backtrace: true
# Log level, options: trace, debug, info, warn or error.
level: debug
# Define the logging format. options: compact, pretty or Json
format: compact
# By default the logger has filtering only logs that came from your code or logs that came from `loco` framework. to see all third party libraries
# Uncomment the line below to override to see all third party libraries you can enable this config and override the logger filters.
# override_filter: trace
# Web server configuration
server:
# Port on which the server will listen. the server binding is 0.0.0.0:{PORT}
port: 5001
# The UI hostname or IP address that mailers will point to.
host: http://webui.konobangu.com
# Out of the box middleware configuration. to disable middleware you can changed the `enable` field to `false` of comment the middleware block
middlewares:
# Enable Etag cache header middleware
etag:
enable: true
# Allows to limit the payload size request. payload that bigger than this file will blocked the request.
limit_payload:
# Enable/Disable the middleware.
enable: true
# the limit size. can be b,kb,kib,mb,mib,gb,gib
body_limit: 5mb
# Generating a unique request ID and enhancing logging with additional information such as the start and completion of request processing, latency, status code, and other request details.
logger:
# Enable/Disable the middleware.
enable: true
# when your code is panicked, the request still returns 500 status code.
catch_panic:
# Enable/Disable the middleware.
enable: true
# Timeout for incoming requests middleware. requests that take more time from the configuration will cute and 408 status code will returned.
timeout_request:
# Enable/Disable the middleware.
enable: false
# Duration time in milliseconds.
timeout: 5000
cors:
enable: true
# Set the value of the [`Access-Control-Allow-Origin`][mdn] header
# allow_origins:
# - https://loco.rs
# Set the value of the [`Access-Control-Allow-Headers`][mdn] header
# allow_headers:
# - Content-Type
# Set the value of the [`Access-Control-Allow-Methods`][mdn] header
# allow_methods:
# - POST
# Set the value of the [`Access-Control-Max-Age`][mdn] header in seconds
# max_age: 3600
# Worker Configuration
workers:
# specifies the worker mode. Options:
# - BackgroundQueue - Workers operate asynchronously in the background, processing queued.
# - ForegroundBlocking - Workers operate in the foreground and block until tasks are completed.
# - BackgroundAsync - Workers operate asynchronously in the background, processing tasks with async capabilities.
mode: BackgroundQueue
# Mailer Configuration.
mailer:
# SMTP mailer configuration.
smtp:
# Enable/Disable smtp mailer.
enable: true
# SMTP server host. e.x localhost, smtp.gmail.com
host: '{{ get_env(name="MAILER_HOST", default="localhost") }}'
# SMTP server port
port: 1025
# Use secure connection (SSL/TLS).
secure: false
# auth:
# user:
# password:
# Database Configuration
database:
# Database connection URI
uri: '{{ get_env(name="DATABASE_URL", default="postgres://konobangu:konobangu@127.0.0.1:5432/konobangu") }}'
# When enabled, the sql query will be logged.
enable_logging: true
# Set the timeout duration when acquiring a connection.
connect_timeout: 500
# Set the idle duration before closing a connection.
idle_timeout: 500
# Minimum number of connections for a pool.
min_connections: 1
# Maximum number of connections for a pool.
max_connections: 1
# Run migration up when application loaded
auto_migrate: true
# Truncate database when application loaded. This is a dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_truncate: false
# Recreating schema when application loaded. This is a dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_recreate: false
# Redis Configuration
redis:
# Redis connection URI
uri: '{{ get_env(name="REDIS_URL", default="redis://127.0.0.1:6379") }}'
# Dangerously flush all data in Redis on startup. dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_flush: false
settings:
dal:
data_dir: ./temp
mikan:
http_client:
exponential_backoff_max_retries: 3
leaky_bucket_max_tokens: 2
leaky_bucket_initial_tokens: 0
leaky_bucket_refill_tokens: 1
leaky_bucket_refill_interval: 500
user_agent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0"
base_url: "https://mikanani.me/"

View File

@ -26,12 +26,6 @@ server:
# Enable Etag cache header middleware # Enable Etag cache header middleware
etag: etag:
enable: true enable: true
# Allows to limit the payload size request. payload that bigger than this file will blocked the request.
limit_payload:
# Enable/Disable the middleware.
enable: true
# the limit size. can be b,kb,kib,mb,mib,gb,gib
body_limit: 5mb
# Generating a unique request ID and enhancing logging with additional information such as the start and completion of request processing, latency, status code, and other request details. # Generating a unique request ID and enhancing logging with additional information such as the start and completion of request processing, latency, status code, and other request details.
logger: logger:
# Enable/Disable the middleware. # Enable/Disable the middleware.
@ -60,32 +54,6 @@ server:
# - POST # - POST
# Set the value of the [`Access-Control-Max-Age`][mdn] header in seconds # Set the value of the [`Access-Control-Max-Age`][mdn] header in seconds
# max_age: 3600 # max_age: 3600
fallback:
enable: false
# Worker Configuration
workers:
# specifies the worker mode. Options:
# - BackgroundQueue - Workers operate asynchronously in the background, processing queued.
# - ForegroundBlocking - Workers operate in the foreground and block until tasks are completed.
# - BackgroundAsync - Workers operate asynchronously in the background, processing tasks with async capabilities.
mode: BackgroundAsync
# Mailer Configuration.
mailer:
# SMTP mailer configuration.
smtp:
# Enable/Disable smtp mailer.
enable: true
# SMTP server host. e.x localhost, smtp.gmail.com
host: '{{ get_env(name="MAILER_HOST", default="localhost") }}'
# SMTP server port
port: 1025
# Use secure connection (SSL/TLS).
secure: false
# auth:
# user:
# password:
# Database Configuration # Database Configuration
database: database:
@ -103,21 +71,9 @@ database:
max_connections: 1 max_connections: 1
# Run migration up when application loaded # Run migration up when application loaded
auto_migrate: true auto_migrate: true
# Truncate database when application loaded. This is a dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_truncate: false
# Recreating schema when application loaded. This is a dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_recreate: false
# Redis Configuration storage:
redis: data_dir: '{{ get_env(name="STORAGE_DATA_DIR", default="./data") }}'
# Redis connection URI
uri: '{{ get_env(name="REDIS_URL", default="redis://localhost:6379") }}'
# Dangerously flush all data in Redis on startup. dangerous operation, make sure that you using this flag only on dev environments or test mode
dangerously_flush: false
settings:
dal:
data_dir: '{{ get_env(name="DAL_DATA_DIR", default="./data") }}'
mikan: mikan:
base_url: "https://mikanani.me/" base_url: "https://mikanani.me/"

View File

@ -1,153 +1,136 @@
use std::{path::Path, sync::Arc}; use std::sync::Arc;
use figment::Figment; use clap::{Parser, command};
use itertools::Itertools;
use super::{core::App, env::Enviornment}; use super::{AppContext, core::App, env::Environment};
use crate::{ use crate::{app::config::AppConfig, errors::RResult};
app::{config::AppConfig, context::create_context, router::create_router},
errors::RResult, #[derive(Parser, Debug)]
}; #[command(version, about, long_about = None)]
pub struct MainCliArgs {
/// Explicit config file path
#[arg(short, long)]
config_file: Option<String>,
/// Explicit dotenv file path
#[arg(short, long)]
dotenv_file: Option<String>,
/// Explicit working dir
#[arg(short, long)]
working_dir: Option<String>,
/// Explicit environment
#[arg(short, long)]
environment: Option<Environment>,
}
pub struct AppBuilder { pub struct AppBuilder {
dotenv_file: Option<String>, dotenv_file: Option<String>,
config_file: Option<String>, config_file: Option<String>,
working_dir: String, working_dir: String,
enviornment: Enviornment, enviornment: Environment,
} }
impl AppBuilder { impl AppBuilder {
pub async fn load_dotenv(&self) -> RResult<()> { pub async fn from_main_cli(environment: Option<Environment>) -> RResult<Self> {
let try_dotenv_file_or_dirs = if self.dotenv_file.is_some() { let args = MainCliArgs::parse();
vec![self.dotenv_file.as_deref()]
let environment = environment.unwrap_or_else(|| {
args.environment.unwrap_or({
if cfg!(test) {
Environment::Testing
} else if cfg!(debug_assertions) {
Environment::Development
} else { } else {
vec![Some(&self.working_dir as &str)] Environment::Production
};
let priority_suffix = &AppConfig::priority_suffix(&self.enviornment);
let dotenv_prefix = AppConfig::dotenv_prefix();
let try_filenames = priority_suffix
.iter()
.map(|ps| format!("{}{}", &dotenv_prefix, ps))
.collect_vec();
for try_dotenv_file_or_dir in try_dotenv_file_or_dirs.into_iter().flatten() {
let try_dotenv_file_or_dir_path = Path::new(try_dotenv_file_or_dir);
if try_dotenv_file_or_dir_path.exists() {
if try_dotenv_file_or_dir_path.is_dir() {
for f in try_filenames.iter() {
let p = try_dotenv_file_or_dir_path.join(f);
if p.exists() && p.is_file() {
dotenv::from_path(p)?;
break;
} }
}
} else if try_dotenv_file_or_dir_path.is_file() {
dotenv::from_path(try_dotenv_file_or_dir_path)?;
break;
}
}
}
Ok(())
}
pub async fn build_config(&self) -> RResult<AppConfig> {
let try_config_file_or_dirs = if self.config_file.is_some() {
vec![self.config_file.as_deref()]
} else {
vec![Some(&self.working_dir as &str)]
};
let allowed_extensions = &AppConfig::allowed_extension();
let priority_suffix = &AppConfig::priority_suffix(&self.enviornment);
let convention_prefix = &AppConfig::config_prefix();
let try_filenames = priority_suffix
.iter()
.flat_map(|ps| {
allowed_extensions
.iter()
.map(move |ext| (format!("{}{}{}", convention_prefix, ps, ext), ext))
}) })
.collect_vec(); });
let mut fig = Figment::from(AppConfig::default_provider()); let mut builder = Self::default();
for try_config_file_or_dir in try_config_file_or_dirs.into_iter().flatten() { if let Some(working_dir) = args.working_dir {
let try_config_file_or_dir_path = Path::new(try_config_file_or_dir); builder = builder.working_dir(working_dir);
if try_config_file_or_dir_path.exists() {
if try_config_file_or_dir_path.is_dir() {
for (f, ext) in try_filenames.iter() {
let p = try_config_file_or_dir_path.join(f);
if p.exists() && p.is_file() {
fig = AppConfig::merge_provider_from_file(fig, &p, ext)?;
break;
}
}
} else if let Some(ext) = try_config_file_or_dir_path
.extension()
.and_then(|s| s.to_str())
&& try_config_file_or_dir_path.is_file()
{
fig =
AppConfig::merge_provider_from_file(fig, try_config_file_or_dir_path, ext)?;
break;
}
} }
if matches!(
&environment,
Environment::Testing | Environment::Development
) {
builder = builder.working_dir_from_manifest_dir();
} }
let app_config: AppConfig = fig.extract()?; builder = builder
.config_file(args.config_file)
.dotenv_file(args.dotenv_file)
.environment(environment);
Ok(app_config) Ok(builder)
} }
pub async fn build(self) -> RResult<App> { pub async fn build(self) -> RResult<App> {
let _app_name = env!("CARGO_CRATE_NAME"); AppConfig::load_dotenv(
&self.enviornment,
&self.working_dir,
self.dotenv_file.as_deref(),
)
.await?;
let _app_version = format!( let config = AppConfig::load_config(
"{} ({})", &self.enviornment,
env!("CARGO_PKG_VERSION"), &self.working_dir,
option_env!("BUILD_SHA") self.config_file.as_deref(),
.or(option_env!("GITHUB_SHA")) )
.unwrap_or("dev") .await?;
let app_context = Arc::new(
AppContext::new(self.enviornment.clone(), config, self.working_dir.clone()).await?,
); );
self.load_dotenv().await?;
let config = self.build_config().await?;
let app_context = Arc::new(create_context(config).await?);
let router = create_router(app_context.clone()).await?;
Ok(App { Ok(App {
context: app_context, context: app_context,
router,
builder: self, builder: self,
}) })
} }
pub fn set_working_dir(self, working_dir: String) -> Self { pub fn working_dir(self, working_dir: String) -> Self {
let mut ret = self; let mut ret = self;
ret.working_dir = working_dir; ret.working_dir = working_dir;
ret ret
} }
pub fn set_working_dir_to_manifest_dir(self) -> Self { pub fn environment(self, environment: Environment) -> Self {
let manifest_dir = if cfg!(debug_assertions) { let mut ret = self;
ret.enviornment = environment;
ret
}
pub fn config_file(self, config_file: Option<String>) -> Self {
let mut ret = self;
ret.config_file = config_file;
ret
}
pub fn dotenv_file(self, dotenv_file: Option<String>) -> Self {
let mut ret = self;
ret.dotenv_file = dotenv_file;
ret
}
pub fn working_dir_from_manifest_dir(self) -> Self {
let manifest_dir = if cfg!(debug_assertions) || cfg!(test) {
env!("CARGO_MANIFEST_DIR") env!("CARGO_MANIFEST_DIR")
} else { } else {
"./apps/recorder" "./apps/recorder"
}; };
self.set_working_dir(manifest_dir.to_string()) self.working_dir(manifest_dir.to_string())
} }
} }
impl Default for AppBuilder { impl Default for AppBuilder {
fn default() -> Self { fn default() -> Self {
Self { Self {
enviornment: Enviornment::Production, enviornment: Environment::Production,
dotenv_file: None, dotenv_file: None,
config_file: None, config_file: None,
working_dir: String::from("."), working_dir: String::from("."),

View File

@ -14,3 +14,5 @@ leaky_bucket_refill_interval = 500
[graphql] [graphql]
depth_limit = inf depth_limit = inf
complexity_limit = inf complexity_limit = inf
[cache]

View File

@ -7,21 +7,26 @@ use figment::{
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::env::Enviornment; use super::env::Environment;
use crate::{ use crate::{
auth::AuthConfig, errors::RResult, extract::mikan::AppMikanConfig, auth::AuthConfig, cache::CacheConfig, database::DatabaseConfig, errors::RResult,
graphql::config::GraphQLConfig, storage::StorageConfig, extract::mikan::MikanConfig, graphql::GraphQLConfig, logger::LoggerConfig,
storage::StorageConfig, web::WebServerConfig,
}; };
const DEFAULT_CONFIG_MIXIN: &str = include_str!("./default_mixin.toml"); const DEFAULT_CONFIG_MIXIN: &str = include_str!("./default_mixin.toml");
const CONFIG_ALLOWED_EXTENSIONS: &[&str] = &[".toml", ".json", ".yaml", ".yml"]; const CONFIG_ALLOWED_EXTENSIONS: &[&str] = &[".toml", ".json", ".yaml", ".yml"];
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig { pub struct AppConfig {
pub server: WebServerConfig,
pub cache: CacheConfig,
pub auth: AuthConfig, pub auth: AuthConfig,
pub dal: StorageConfig, pub storage: StorageConfig,
pub mikan: AppMikanConfig, pub mikan: MikanConfig,
pub graphql: GraphQLConfig, pub graphql: GraphQLConfig,
pub logger: LoggerConfig,
pub database: DatabaseConfig,
} }
impl AppConfig { impl AppConfig {
@ -40,13 +45,13 @@ impl AppConfig {
.collect_vec() .collect_vec()
} }
pub fn priority_suffix(enviornment: &Enviornment) -> Vec<String> { pub fn priority_suffix(environment: &Environment) -> Vec<String> {
vec![ vec![
format!(".{}.local", enviornment.full_name()), format!(".{}.local", environment.full_name()),
format!(".{}.local", enviornment.short_name()), format!(".{}.local", environment.short_name()),
String::from(".local"), String::from(".local"),
enviornment.full_name().to_string(), environment.full_name().to_string(),
enviornment.short_name().to_string(), environment.short_name().to_string(),
String::from(""), String::from(""),
] ]
} }
@ -75,4 +80,97 @@ impl AppConfig {
_ => unreachable!("unsupported config extension"), _ => unreachable!("unsupported config extension"),
}) })
} }
pub async fn load_dotenv(
environment: &Environment,
working_dir: &str,
dotenv_file: Option<&str>,
) -> RResult<()> {
let try_dotenv_file_or_dirs = if dotenv_file.is_some() {
vec![dotenv_file]
} else {
vec![Some(working_dir)]
};
let priority_suffix = &AppConfig::priority_suffix(environment);
let dotenv_prefix = AppConfig::dotenv_prefix();
let try_filenames = priority_suffix
.iter()
.map(|ps| format!("{}{}", &dotenv_prefix, ps))
.collect_vec();
for try_dotenv_file_or_dir in try_dotenv_file_or_dirs.into_iter().flatten() {
let try_dotenv_file_or_dir_path = Path::new(try_dotenv_file_or_dir);
if try_dotenv_file_or_dir_path.exists() {
if try_dotenv_file_or_dir_path.is_dir() {
for f in try_filenames.iter() {
let p = try_dotenv_file_or_dir_path.join(f);
if p.exists() && p.is_file() {
dotenv::from_path(p)?;
break;
}
}
} else if try_dotenv_file_or_dir_path.is_file() {
dotenv::from_path(try_dotenv_file_or_dir_path)?;
break;
}
}
}
Ok(())
}
pub async fn load_config(
environment: &Environment,
working_dir: &str,
config_file: Option<&str>,
) -> RResult<AppConfig> {
let try_config_file_or_dirs = if config_file.is_some() {
vec![config_file]
} else {
vec![Some(working_dir)]
};
let allowed_extensions = &AppConfig::allowed_extension();
let priority_suffix = &AppConfig::priority_suffix(environment);
let convention_prefix = &AppConfig::config_prefix();
let try_filenames = priority_suffix
.iter()
.flat_map(|ps| {
allowed_extensions
.iter()
.map(move |ext| (format!("{}{}{}", convention_prefix, ps, ext), ext))
})
.collect_vec();
let mut fig = Figment::from(AppConfig::default_provider());
for try_config_file_or_dir in try_config_file_or_dirs.into_iter().flatten() {
let try_config_file_or_dir_path = Path::new(try_config_file_or_dir);
if try_config_file_or_dir_path.exists() {
if try_config_file_or_dir_path.is_dir() {
for (f, ext) in try_filenames.iter() {
let p = try_config_file_or_dir_path.join(f);
if p.exists() && p.is_file() {
fig = AppConfig::merge_provider_from_file(fig, &p, ext)?;
break;
}
}
} else if let Some(ext) = try_config_file_or_dir_path
.extension()
.and_then(|s| s.to_str())
&& try_config_file_or_dir_path.is_file()
{
fig =
AppConfig::merge_provider_from_file(fig, try_config_file_or_dir_path, ext)?;
break;
}
}
}
let app_config: AppConfig = fig.extract()?;
Ok(app_config)
}
} }

View File

@ -1,13 +1,13 @@
use sea_orm::DatabaseConnection; use super::{Environment, config::AppConfig};
use super::config::AppConfig;
use crate::{ use crate::{
auth::AuthService, cache::CacheService, errors::RResult, extract::mikan::MikanClient, auth::AuthService, cache::CacheService, database::DatabaseService, errors::RResult,
graphql::GraphQLService, storage::StorageService, extract::mikan::MikanClient, graphql::GraphQLService, logger::LoggerService,
storage::StorageService,
}; };
pub struct AppContext { pub struct AppContext {
pub db: DatabaseConnection, pub logger: LoggerService,
pub db: DatabaseService,
pub config: AppConfig, pub config: AppConfig,
pub cache: CacheService, pub cache: CacheService,
pub mikan: MikanClient, pub mikan: MikanClient,
@ -15,8 +15,36 @@ pub struct AppContext {
pub graphql: GraphQLService, pub graphql: GraphQLService,
pub storage: StorageService, pub storage: StorageService,
pub working_dir: String, pub working_dir: String,
pub environment: Environment,
} }
pub async fn create_context(_config: AppConfig) -> RResult<AppContext> { impl AppContext {
todo!() pub async fn new(
environment: Environment,
config: AppConfig,
working_dir: impl ToString,
) -> RResult<Self> {
let config_cloned = config.clone();
let logger = LoggerService::from_config(config.logger).await?;
let cache = CacheService::from_config(config.cache).await?;
let db = DatabaseService::from_config(config.database).await?;
let storage = StorageService::from_config(config.storage).await?;
let auth = AuthService::from_conf(config.auth).await?;
let mikan = MikanClient::from_config(config.mikan).await?;
let graphql = GraphQLService::from_config_and_database(config.graphql, db.clone()).await?;
Ok(AppContext {
config: config_cloned,
environment,
logger,
auth,
cache,
db,
storage,
mikan,
working_dir: working_dir.to_string(),
graphql,
})
}
} }

View File

@ -1,15 +1,89 @@
use std::sync::Arc; use std::{net::SocketAddr, sync::Arc};
use super::{builder::AppBuilder, context::AppContext, router::AppRouter}; use axum::Router;
use futures::try_join;
use tokio::signal;
use super::{builder::AppBuilder, context::AppContext};
use crate::{
errors::RResult,
web::{
controller::{self, core::ControllerTrait},
middleware::default_middleware_stack,
},
};
pub struct App { pub struct App {
pub context: Arc<AppContext>, pub context: Arc<AppContext>,
pub builder: AppBuilder, pub builder: AppBuilder,
pub router: AppRouter,
} }
impl App { impl App {
pub fn builder() -> AppBuilder { pub fn builder() -> AppBuilder {
AppBuilder::default() AppBuilder::default()
} }
pub async fn serve(&self) -> RResult<()> {
let context = &self.context;
let config = &context.config;
let listener = tokio::net::TcpListener::bind(&format!(
"{}:{}",
config.server.binding, config.server.port
))
.await?;
let mut router = Router::<Arc<AppContext>>::new();
let (graphqlc, oidcc) = try_join!(
controller::graphql::create(context.clone()),
controller::oidc::create(context.clone()),
)?;
for c in [graphqlc, oidcc] {
router = c.apply_to(router);
}
let middlewares = default_middleware_stack(context.clone());
for mid in middlewares {
router = mid.apply(router)?;
tracing::info!(name = mid.name(), "+middleware");
}
let router = router
.with_state(context.clone())
.into_make_service_with_connect_info::<SocketAddr>();
axum::serve(listener, router)
.with_graceful_shutdown(async move {
Self::shutdown_signal().await;
tracing::info!("shutting down...");
})
.await?;
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
() = ctrl_c => {},
() = terminate => {},
}
}
} }

View File

@ -1,10 +1,22 @@
pub enum Enviornment { use clap::ValueEnum;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "snake_case")]
#[value(rename_all = "snake_case")]
pub enum Environment {
#[serde(alias = "dev")]
#[value(alias = "dev")]
Development, Development,
#[serde(alias = "prod")]
#[value(alias = "prod")]
Production, Production,
#[serde(alias = "test")]
#[value(alias = "test")]
Testing, Testing,
} }
impl Enviornment { impl Environment {
pub fn full_name(&self) -> &'static str { pub fn full_name(&self) -> &'static str {
match &self { match &self {
Self::Development => "development", Self::Development => "development",

View File

@ -3,75 +3,10 @@ pub mod config;
pub mod context; pub mod context;
pub mod core; pub mod core;
pub mod env; pub mod env;
pub mod router;
pub use core::App; pub use core::App;
use std::path::Path;
use async_trait::async_trait; pub use builder::AppBuilder;
pub use config::AppConfig;
pub use context::AppContext; pub use context::AppContext;
use loco_rs::{ pub use env::Environment;
Result,
app::{AppContext as LocoAppContext, Hooks},
boot::{BootResult, StartMode, create_app},
config::Config,
controller::AppRoutes,
db::truncate_table,
environment::Environment,
prelude::*,
task::Tasks,
};
use crate::{migrations::Migrator, models::subscribers};
pub struct App1;
#[async_trait]
impl Hooks for App1 {
fn app_version() -> String {
format!(
"{} ({})",
env!("CARGO_PKG_VERSION"),
option_env!("BUILD_SHA")
.or(option_env!("GITHUB_SHA"))
.unwrap_or("dev")
)
}
fn app_name() -> &'static str {
env!("CARGO_CRATE_NAME")
}
async fn boot(
mode: StartMode,
environment: &Environment,
config: Config,
) -> Result<BootResult> {
create_app::<Self, Migrator>(mode, environment, config).await
}
async fn initializers(_ctx: &LocoAppContext) -> Result<Vec<Box<dyn Initializer>>> {
let initializers: Vec<Box<dyn Initializer>> = vec![];
Ok(initializers)
}
fn routes(_ctx: &LocoAppContext) -> AppRoutes {
AppRoutes::with_default_routes()
}
fn register_tasks(_tasks: &mut Tasks) {}
async fn truncate(ctx: &LocoAppContext) -> Result<()> {
truncate_table(&ctx.db, subscribers::Entity).await?;
Ok(())
}
async fn seed(_ctx: &LocoAppContext, _base: &Path) -> Result<()> {
Ok(())
}
async fn connect_workers(_ctx: &LocoAppContext, _queue: &Queue) -> Result<()> {
Ok(())
}
}

View File

@ -1,31 +0,0 @@
use std::sync::Arc;
use axum::Router;
use futures::try_join;
use crate::{
app::AppContext,
controllers::{self, core::ControllerTrait},
errors::RResult,
};
pub struct AppRouter {
pub root: Router<Arc<AppContext>>,
}
pub async fn create_router(context: Arc<AppContext>) -> RResult<AppRouter> {
let mut root_router = Router::<Arc<AppContext>>::new();
let (graphqlc, oidcc) = try_join!(
controllers::graphql::create(context.clone()),
controllers::oidc::create(context.clone()),
)?;
for c in [graphqlc, oidcc] {
root_router = c.apply_to(root_router);
}
root_router = root_router.with_state(context);
Ok(AppRouter { root: root_router })
}

View File

@ -77,7 +77,7 @@ impl AuthServiceTrait for BasicAuthService {
{ {
let subscriber_auth = crate::models::auth::Model::find_by_pid(ctx, SEED_SUBSCRIBER) let subscriber_auth = crate::models::auth::Model::find_by_pid(ctx, SEED_SUBSCRIBER)
.await .await
.map_err(AuthError::FindAuthRecordError)?; .map_err(|_| AuthError::FindAuthRecordError)?;
return Ok(AuthUserInfo { return Ok(AuthUserInfo {
subscriber_auth, subscriber_auth,
auth_type: AuthType::Basic, auth_type: AuthType::Basic,

View File

@ -13,7 +13,7 @@ use openidconnect::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use crate::{errors::RError, fetch::HttpClientError, models::auth::AuthType}; use crate::{fetch::HttpClientError, models::auth::AuthType};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AuthError { pub enum AuthError {
@ -23,7 +23,7 @@ pub enum AuthError {
current: AuthType, current: AuthType,
}, },
#[error("Failed to find auth record")] #[error("Failed to find auth record")]
FindAuthRecordError(RError), FindAuthRecordError,
#[error("Invalid credentials")] #[error("Invalid credentials")]
BasicInvalidCredentials, BasicInvalidCredentials,
#[error(transparent)] #[error(transparent)]

View File

@ -311,7 +311,7 @@ impl AuthServiceTrait for OidcAuthService {
} }
r => r, r => r,
} }
.map_err(AuthError::FindAuthRecordError)?; .map_err(|_| AuthError::FindAuthRecordError)?;
Ok(AuthUserInfo { Ok(AuthUserInfo {
subscriber_auth, subscriber_auth,

View File

@ -1,9 +1,15 @@
use loco_rs::cli; use color_eyre::{self, eyre};
use recorder::{app::App1, migrations::Migrator}; use recorder::app::AppBuilder;
#[tokio::main] #[tokio::main]
async fn main() -> color_eyre::eyre::Result<()> { async fn main() -> eyre::Result<()> {
color_eyre::install()?; color_eyre::install()?;
cli::main::<App1, Migrator>().await?;
let builder = AppBuilder::from_main_cli(None).await?;
let app = builder.build().await?;
app.serve().await?;
Ok(()) Ok(())
} }

View File

@ -1 +1,4 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CacheConfig {} pub struct CacheConfig {}

View File

@ -1 +1,10 @@
use super::CacheConfig;
use crate::errors::RResult;
pub struct CacheService {} pub struct CacheService {}
impl CacheService {
pub async fn from_config(_config: CacheConfig) -> RResult<Self> {
Ok(Self {})
}
}

View File

@ -1,3 +0,0 @@
pub mod core;
pub mod graphql;
pub mod oidc;

View File

@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatabaseConfig {
pub uri: String,
pub enable_logging: bool,
pub min_connections: u32,
pub max_connections: u32,
pub connect_timeout: u64,
pub idle_timeout: u64,
pub acquire_timeout: Option<u64>,
#[serde(default)]
pub auto_migrate: bool,
}

View File

@ -0,0 +1,5 @@
pub mod config;
pub mod service;
pub use config::DatabaseConfig;
pub use service::DatabaseService;

View File

@ -0,0 +1,97 @@
use std::{ops::Deref, time::Duration};
use sea_orm::{
ConnectOptions, ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, DbBackend,
DbErr, ExecResult, QueryResult, Statement,
};
use sea_orm_migration::MigratorTrait;
use super::DatabaseConfig;
use crate::{errors::RResult, migrations::Migrator};
pub struct DatabaseService {
connection: DatabaseConnection,
}
impl DatabaseService {
pub async fn from_config(config: DatabaseConfig) -> RResult<Self> {
let mut opt = ConnectOptions::new(&config.uri);
opt.max_connections(config.max_connections)
.min_connections(config.min_connections)
.connect_timeout(Duration::from_millis(config.connect_timeout))
.idle_timeout(Duration::from_millis(config.idle_timeout))
.sqlx_logging(config.enable_logging);
if let Some(acquire_timeout) = config.acquire_timeout {
opt.acquire_timeout(Duration::from_millis(acquire_timeout));
}
let db = Database::connect(opt).await?;
if db.get_database_backend() == DatabaseBackend::Sqlite {
db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
"
PRAGMA foreign_keys = ON;
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA mmap_size = 134217728;
PRAGMA journal_size_limit = 67108864;
PRAGMA cache_size = 2000;
",
))
.await?;
}
if config.auto_migrate {
Migrator::up(&db, None).await?;
}
Ok(Self { connection: db })
}
}
/// Allow the service to be used anywhere a `&DatabaseConnection` is expected.
impl Deref for DatabaseService {
    type Target = DatabaseConnection;

    fn deref(&self) -> &Self::Target {
        &self.connection
    }
}

impl AsRef<DatabaseConnection> for DatabaseService {
    fn as_ref(&self) -> &DatabaseConnection {
        &self.connection
    }
}

/// Forward the entire [`ConnectionTrait`] surface to the inner connection so
/// the service can be handed directly to SeaORM query APIs.
#[async_trait::async_trait]
impl ConnectionTrait for DatabaseService {
    fn get_database_backend(&self) -> DbBackend {
        self.connection.get_database_backend()
    }

    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        self.connection.execute(stmt).await
    }

    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        self.connection.execute_unprepared(sql).await
    }

    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        self.connection.query_one(stmt).await
    }

    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        self.connection.query_all(stmt).await
    }

    fn support_returning(&self) -> bool {
        self.connection.support_returning()
    }

    fn is_mock_connection(&self) -> bool {
        self.connection.is_mock_connection()
    }
}

View File

@ -1,11 +1,23 @@
use std::{borrow::Cow, error::Error as StdError}; use std::{borrow::Cow, error::Error as StdError};
use axum::response::{IntoResponse, Response};
use http::StatusCode;
use thiserror::Error as ThisError; use thiserror::Error as ThisError;
use crate::fetch::HttpClientError; use crate::{auth::AuthError, fetch::HttpClientError};
#[derive(ThisError, Debug)] #[derive(ThisError, Debug)]
pub enum RError { pub enum RError {
#[error(transparent)]
InvalidMethodError(#[from] http::method::InvalidMethod),
#[error(transparent)]
InvalidHeaderNameError(#[from] http::header::InvalidHeaderName),
#[error(transparent)]
TracingAppenderInitError(#[from] tracing_appender::rolling::InitError),
#[error(transparent)]
GraphQLSchemaError(#[from] async_graphql::dynamic::SchemaError),
#[error(transparent)]
AuthError(#[from] AuthError),
#[error(transparent)] #[error(transparent)]
RSSError(#[from] rss::Error), RSSError(#[from] rss::Error),
#[error(transparent)] #[error(transparent)]
@ -56,6 +68,10 @@ pub enum RError {
}, },
#[error("Model Entity {entity} not found")] #[error("Model Entity {entity} not found")]
ModelEntityNotFound { entity: Cow<'static, str> }, ModelEntityNotFound { entity: Cow<'static, str> },
#[error("{0}")]
CustomMessageStr(&'static str),
#[error("{0}")]
CustomMessageString(String),
} }
impl RError { impl RError {
@ -88,4 +104,13 @@ impl RError {
} }
} }
impl IntoResponse for RError {
fn into_response(self) -> Response {
match self {
Self::AuthError(auth_error) => auth_error.into_response(),
err => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
}
}
}
pub type RResult<T> = Result<T, RError>; pub type RResult<T> = Result<T, RError>;

View File

@ -4,7 +4,7 @@ use reqwest_middleware::ClientWithMiddleware;
use secrecy::{ExposeSecret, SecretString}; use secrecy::{ExposeSecret, SecretString};
use url::Url; use url::Url;
use super::AppMikanConfig; use super::MikanConfig;
use crate::{ use crate::{
errors::RError, errors::RError,
fetch::{HttpClient, HttpClientTrait, client::HttpClientCookiesAuth}, fetch::{HttpClient, HttpClientTrait, client::HttpClientCookiesAuth},
@ -29,7 +29,7 @@ pub struct MikanClient {
} }
impl MikanClient { impl MikanClient {
pub fn new(config: AppMikanConfig) -> Result<Self, RError> { pub async fn from_config(config: MikanConfig) -> Result<Self, RError> {
let http_client = HttpClient::from_config(config.http_client)?; let http_client = HttpClient::from_config(config.http_client)?;
let base_url = config.base_url; let base_url = config.base_url;
Ok(Self { Ok(Self {

View File

@ -4,7 +4,7 @@ use url::Url;
use crate::fetch::HttpClientConfig; use crate::fetch::HttpClientConfig;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AppMikanConfig { pub struct MikanConfig {
pub http_client: HttpClientConfig, pub http_client: HttpClientConfig,
pub base_url: Url, pub base_url: Url,
} }

View File

@ -5,7 +5,7 @@ pub mod rss_extract;
pub mod web_extract; pub mod web_extract;
pub use client::{MikanAuthSecrecy, MikanClient}; pub use client::{MikanAuthSecrecy, MikanClient};
pub use config::AppMikanConfig; pub use config::MikanConfig;
pub use constants::MIKAN_BUCKET_KEY; pub use constants::MIKAN_BUCKET_KEY;
pub use rss_extract::{ pub use rss_extract::{
MikanBangumiAggregationRssChannel, MikanBangumiRssChannel, MikanBangumiRssLink, MikanBangumiAggregationRssChannel, MikanBangumiRssChannel, MikanBangumiRssLink,

View File

@ -354,7 +354,7 @@ mod tests {
let mikan_base_url = Url::parse(&mikan_server.url())?; let mikan_base_url = Url::parse(&mikan_server.url())?;
let mikan_client = build_testing_mikan_client(mikan_base_url.clone())?; let mikan_client = build_testing_mikan_client(mikan_base_url.clone()).await?;
{ {
let bangumi_rss_url = let bangumi_rss_url =

View File

@ -509,7 +509,7 @@ mod test {
async fn test_extract_mikan_poster_from_src(before_each: ()) -> eyre::Result<()> { async fn test_extract_mikan_poster_from_src(before_each: ()) -> eyre::Result<()> {
let mut mikan_server = mockito::Server::new_async().await; let mut mikan_server = mockito::Server::new_async().await;
let mikan_base_url = Url::parse(&mikan_server.url())?; let mikan_base_url = Url::parse(&mikan_server.url())?;
let mikan_client = build_testing_mikan_client(mikan_base_url.clone())?; let mikan_client = build_testing_mikan_client(mikan_base_url.clone()).await?;
let bangumi_poster_url = mikan_base_url.join("/images/Bangumi/202309/5ce9fed1.jpg")?; let bangumi_poster_url = mikan_base_url.join("/images/Bangumi/202309/5ce9fed1.jpg")?;
@ -540,7 +540,7 @@ mod test {
async fn test_extract_mikan_episode(before_each: ()) -> eyre::Result<()> { async fn test_extract_mikan_episode(before_each: ()) -> eyre::Result<()> {
let mut mikan_server = mockito::Server::new_async().await; let mut mikan_server = mockito::Server::new_async().await;
let mikan_base_url = Url::parse(&mikan_server.url())?; let mikan_base_url = Url::parse(&mikan_server.url())?;
let mikan_client = build_testing_mikan_client(mikan_base_url.clone())?; let mikan_client = build_testing_mikan_client(mikan_base_url.clone()).await?;
let episode_homepage_url = let episode_homepage_url =
mikan_base_url.join("/Home/Episode/475184dce83ea2b82902592a5ac3343f6d54b36a")?; mikan_base_url.join("/Home/Episode/475184dce83ea2b82902592a5ac3343f6d54b36a")?;
@ -582,7 +582,7 @@ mod test {
) -> eyre::Result<()> { ) -> eyre::Result<()> {
let mut mikan_server = mockito::Server::new_async().await; let mut mikan_server = mockito::Server::new_async().await;
let mikan_base_url = Url::parse(&mikan_server.url())?; let mikan_base_url = Url::parse(&mikan_server.url())?;
let mikan_client = build_testing_mikan_client(mikan_base_url.clone())?; let mikan_client = build_testing_mikan_client(mikan_base_url.clone()).await?;
let bangumi_homepage_url = mikan_base_url.join("/Home/Bangumi/3416#370")?; let bangumi_homepage_url = mikan_base_url.join("/Home/Bangumi/3416#370")?;
@ -625,7 +625,7 @@ mod test {
let my_bangumi_page_url = mikan_base_url.join("/Home/MyBangumi")?; let my_bangumi_page_url = mikan_base_url.join("/Home/MyBangumi")?;
let mikan_client = build_testing_mikan_client(mikan_base_url.clone())?; let mikan_client = build_testing_mikan_client(mikan_base_url.clone()).await?;
{ {
let my_bangumi_without_cookie_mock = mikan_server let my_bangumi_without_cookie_mock = mikan_server

View File

@ -5,5 +5,6 @@ pub mod schema_root;
pub mod service; pub mod service;
pub mod util; pub mod util;
pub use config::GraphQLConfig;
pub use schema_root::schema; pub use schema_root::schema;
pub use service::GraphQLService; pub use service::GraphQLService;

View File

@ -1,7 +1,8 @@
use async_graphql::dynamic::{Schema, SchemaError}; use async_graphql::dynamic::Schema;
use sea_orm::DatabaseConnection; use sea_orm::DatabaseConnection;
use super::{config::GraphQLConfig, schema_root}; use super::{config::GraphQLConfig, schema_root};
use crate::errors::RResult;
#[derive(Debug)] #[derive(Debug)]
pub struct GraphQLService { pub struct GraphQLService {
@ -9,7 +10,10 @@ pub struct GraphQLService {
} }
impl GraphQLService { impl GraphQLService {
pub fn new(config: GraphQLConfig, db: DatabaseConnection) -> Result<Self, SchemaError> { pub async fn from_config_and_database(
config: GraphQLConfig,
db: DatabaseConnection,
) -> RResult<Self> {
let schema = schema_root::schema(db, config.depth_limit, config.complexity_limit)?; let schema = schema_root::schema(db, config.depth_limit, config.complexity_limit)?;
Ok(Self { schema }) Ok(Self { schema })
} }

View File

@ -11,11 +11,12 @@
pub mod app; pub mod app;
pub mod auth; pub mod auth;
pub mod cache; pub mod cache;
pub mod controllers; pub mod database;
pub mod errors; pub mod errors;
pub mod extract; pub mod extract;
pub mod fetch; pub mod fetch;
pub mod graphql; pub mod graphql;
pub mod logger;
pub mod migrations; pub mod migrations;
pub mod models; pub mod models;
pub mod storage; pub mod storage;
@ -24,3 +25,4 @@ pub mod tasks;
#[cfg(test)] #[cfg(test)]
pub mod test_utils; pub mod test_utils;
pub mod views; pub mod views;
pub mod web;

View File

@ -0,0 +1,38 @@
use serde::{Deserialize, Serialize};
use super::{
LogRotation,
core::{LogFormat, LogLevel},
};
/// Top-level logger configuration: a stdout layer plus an optional rolling
/// file appender.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct LoggerConfig {
    /// Enable the stdout (console) logging layer.
    pub enable: bool,
    /// Set `RUST_BACKTRACE=1` at startup for readable backtraces
    /// (development convenience; has a runtime cost in production).
    #[serde(default)]
    pub pretty_backtrace: bool,
    /// Minimum level used when building the default filter.
    pub level: LogLevel,
    /// Output format for the stdout layer.
    pub format: LogFormat,
    /// Module filter. NOTE(review): not currently read by `LoggerService`,
    /// which only consumes `override_filter` — confirm before relying on it.
    pub filter: Option<String>,
    /// Full `EnvFilter` directive string that replaces the built-in module
    /// whitelist when present.
    pub override_filter: Option<String>,
    /// Optional rolling-file appender settings.
    pub file_appender: Option<LoggerFileAppender>,
}
/// Rolling file appender configuration.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct LoggerFileAppender {
    /// Enable the file appender layer.
    pub enable: bool,
    /// Write through a non-blocking worker thread instead of synchronously.
    #[serde(default)]
    pub non_blocking: bool,
    /// Minimum level for the file layer. NOTE(review): not currently applied
    /// by `LoggerService` — confirm.
    pub level: LogLevel,
    /// Output format for the file layer (ANSI is always disabled for files).
    pub format: LogFormat,
    /// File rotation policy.
    pub rotation: LogRotation,
    /// Log directory; `LoggerService` defaults it to `./logs` when unset.
    pub dir: Option<String>,
    /// Rotated filename prefix (empty string when unset).
    pub filename_prefix: Option<String>,
    /// Rotated filename suffix (empty string when unset).
    pub filename_suffix: Option<String>,
    /// Maximum number of rotated files to keep.
    pub max_log_files: usize,
}

View File

@ -0,0 +1,49 @@
use serde::{Deserialize, Serialize};
use serde_variant::to_variant_name;
/// Verbosity threshold for a logging layer.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub enum LogLevel {
    /// Disable logging entirely.
    #[serde(rename = "off")]
    Off,
    #[serde(rename = "trace")]
    Trace,
    #[serde(rename = "debug")]
    Debug,
    /// Default level.
    #[serde(rename = "info")]
    #[default]
    Info,
    #[serde(rename = "warn")]
    Warn,
    #[serde(rename = "error")]
    Error,
}
/// Renders the serde variant name (e.g. `info`), so the value can be embedded
/// directly into `EnvFilter` directive strings.
impl std::fmt::Display for LogLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        to_variant_name(self).expect("only enum supported").fmt(f)
    }
}
/// Output format of a logging layer.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub enum LogFormat {
    /// Single-line, terse output (default).
    #[serde(rename = "compact")]
    #[default]
    Compact,
    /// Multi-line, human-oriented output.
    #[serde(rename = "pretty")]
    Pretty,
    /// Structured JSON output.
    #[serde(rename = "json")]
    Json,
}
/// Rotation policy for the rolling file appender.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub enum LogRotation {
    #[serde(rename = "minutely")]
    Minutely,
    /// Default rotation.
    #[serde(rename = "hourly")]
    #[default]
    Hourly,
    #[serde(rename = "daily")]
    Daily,
    #[serde(rename = "never")]
    Never,
}

View File

@ -0,0 +1,8 @@
pub mod config;
pub mod core;
pub mod service;
pub use core::{LogFormat, LogLevel, LogRotation};
pub use config::{LoggerConfig, LoggerFileAppender};
pub use service::LoggerService;

View File

@ -0,0 +1,162 @@
use std::sync::OnceLock;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{
EnvFilter, Layer, Registry,
fmt::{self, MakeWriter},
layer::SubscriberExt,
util::SubscriberInitExt,
};
use super::{LogFormat, LogLevel, LogRotation, LoggerConfig};
use crate::errors::{RError, RResult};
// Function to initialize the logger based on the provided configuration
const MODULE_WHITELIST: &[&str] = &["sea_orm_migration", "tower_http", "sqlx::query", "sidekiq"];
// Keep nonblocking file appender work guard
static NONBLOCKING_WORK_GUARD_KEEP: OnceLock<WorkerGuard> = OnceLock::new();
pub struct LoggerService {}
impl LoggerService {
/// Builds one formatting layer writing to `make_writer`.
///
/// `ansi` controls color escapes (enabled for stdout, disabled for files);
/// `format` selects the compact, pretty or JSON renderer.
pub fn init_layer<W2>(
    make_writer: W2,
    format: &LogFormat,
    ansi: bool,
) -> Box<dyn Layer<Registry> + Sync + Send>
where
    W2: for<'writer> MakeWriter<'writer> + Sync + Send + 'static,
{
    // Every format shares the same writer/ANSI base configuration.
    let base = fmt::Layer::default()
        .with_ansi(ansi)
        .with_writer(make_writer);
    match format {
        LogFormat::Compact => base.compact().boxed(),
        LogFormat::Pretty => base.pretty().boxed(),
        LogFormat::Json => base.json().boxed(),
    }
}
/// Builds the `EnvFilter` shared by all layers.
///
/// Precedence: a `RUST_LOG`-style directive from the environment wins;
/// otherwise the configured `override_filter`; otherwise a default filter
/// enabling `level` for this crate plus the `MODULE_WHITELIST` modules.
///
/// # Panics
///
/// Panics if none of the candidate directive strings parse.
fn init_env_filter(override_filter: Option<&String>, level: &LogLevel) -> EnvFilter {
    EnvFilter::try_from_default_env()
        .or_else(|_| {
            // user wanted a specific filter, don't care about our internal whitelist
            // or, if no override give them the default whitelisted filter (most common)
            override_filter.map_or_else(
                || {
                    EnvFilter::try_new(
                        MODULE_WHITELIST
                            .iter()
                            .map(|m| format!("{m}={level}"))
                            .chain(std::iter::once(format!(
                                "{}={}",
                                env!("CARGO_CRATE_NAME"),
                                level
                            )))
                            .collect::<Vec<_>>()
                            .join(","),
                    )
                },
                EnvFilter::try_new,
            )
        })
        .expect("logger initialization failed")
}
/// Installs the global `tracing` subscriber described by `config`.
///
/// Assembles, in order: an optional rolling-file layer (blocking or
/// non-blocking) and an optional stdout layer, both behind one `EnvFilter`
/// from [`Self::init_env_filter`]. Must be called at most once per process:
/// the subscriber registry and the non-blocking worker guard are global.
/// The function is `async` only for uniformity with the other service
/// constructors; it performs no awaiting itself.
///
/// # Errors
///
/// Fails if the rolling appender cannot be built or if the worker guard was
/// already stored by a previous call.
pub async fn from_config(config: LoggerConfig) -> RResult<Self> {
    let mut layers: Vec<Box<dyn Layer<Registry> + Sync + Send>> = Vec::new();
    if let Some(file_appender_config) = config.file_appender.as_ref() {
        if file_appender_config.enable {
            let dir = file_appender_config
                .dir
                .as_ref()
                .map_or_else(|| "./logs".to_string(), ToString::to_string);
            let mut rolling_builder = tracing_appender::rolling::Builder::default()
                .max_log_files(file_appender_config.max_log_files);
            rolling_builder = match file_appender_config.rotation {
                LogRotation::Minutely => {
                    rolling_builder.rotation(tracing_appender::rolling::Rotation::MINUTELY)
                }
                LogRotation::Hourly => {
                    rolling_builder.rotation(tracing_appender::rolling::Rotation::HOURLY)
                }
                LogRotation::Daily => {
                    rolling_builder.rotation(tracing_appender::rolling::Rotation::DAILY)
                }
                LogRotation::Never => {
                    rolling_builder.rotation(tracing_appender::rolling::Rotation::NEVER)
                }
            };
            let file_appender = rolling_builder
                .filename_prefix(
                    file_appender_config
                        .filename_prefix
                        .as_ref()
                        .map_or_else(String::new, ToString::to_string),
                )
                .filename_suffix(
                    file_appender_config
                        .filename_suffix
                        .as_ref()
                        .map_or_else(String::new, ToString::to_string),
                )
                .build(dir)?;
            let file_appender_layer = if file_appender_config.non_blocking {
                // The worker guard must outlive the process's logging; it is
                // parked in a static OnceLock, so a second non-blocking init
                // fails here rather than silently dropping the first guard.
                let (non_blocking_file_appender, work_guard) =
                    tracing_appender::non_blocking(file_appender);
                NONBLOCKING_WORK_GUARD_KEEP
                    .set(work_guard)
                    .map_err(|_| RError::CustomMessageStr("cannot lock for appender"))?;
                Self::init_layer(
                    non_blocking_file_appender,
                    &file_appender_config.format,
                    false,
                )
            } else {
                Self::init_layer(file_appender, &file_appender_config.format, false)
            };
            layers.push(file_appender_layer);
        }
    }
    if config.enable {
        let stdout_layer = Self::init_layer(std::io::stdout, &config.format, true);
        layers.push(stdout_layer);
    }
    if !layers.is_empty() {
        let env_filter = Self::init_env_filter(config.override_filter.as_ref(), &config.level);
        tracing_subscriber::registry()
            .with(layers)
            .with(env_filter)
            .init();
    }
    if config.pretty_backtrace {
        // NOTE(review): `set_var` mutates process-global state and is not
        // thread-safe — assumes single-threaded startup; confirm call site.
        unsafe {
            std::env::set_var("RUST_BACKTRACE", "1");
        }
        tracing::warn!(
            "pretty backtraces are enabled (this is great for development but has a runtime \
             cost for production. disable with `logger.pretty_backtrace` in your config yaml)"
        );
    }
    Ok(Self {})
}
}

View File

@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use sea_orm::{Set, TransactionTrait, entity::prelude::*}; use sea_orm::{EntityTrait, Set, TransactionTrait, prelude::*};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::subscribers::{self, SEED_SUBSCRIBER}; use super::subscribers::{self, SEED_SUBSCRIBER};

View File

@ -10,10 +10,10 @@
// Import Routes // Import Routes
import { Route as rootRoute } from './controllers/__root' import { Route as rootRoute } from './web/controller/__root'
import { Route as IndexImport } from './controllers/index' import { Route as IndexImport } from './web/controller/index'
import { Route as GraphqlIndexImport } from './controllers/graphql/index' import { Route as GraphqlIndexImport } from './web/controller/graphql/index'
import { Route as OidcCallbackImport } from './controllers/oidc/callback' import { Route as OidcCallbackImport } from './web/controller/oidc/callback'
// Create/Update Routes // Create/Update Routes

View File

@ -8,7 +8,7 @@ use url::Url;
use uuid::Uuid; use uuid::Uuid;
use super::StorageConfig; use super::StorageConfig;
use crate::errors::RError; use crate::errors::{RError, RResult};
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -50,10 +50,10 @@ pub struct StorageService {
} }
impl StorageService { impl StorageService {
pub fn from_config(config: StorageConfig) -> Self { pub async fn from_config(config: StorageConfig) -> RResult<Self> {
Self { Ok(Self {
data_dir: config.data_dir, data_dir: config.data_dir.to_string(),
} })
} }
pub fn get_fs(&self) -> Fs { pub fn get_fs(&self) -> Fs {

View File

@ -1,17 +1,18 @@
use color_eyre::eyre;
use reqwest::IntoUrl; use reqwest::IntoUrl;
use crate::{ use crate::{
extract::mikan::{AppMikanConfig, MikanClient}, errors::RResult,
extract::mikan::{MikanClient, MikanConfig},
fetch::HttpClientConfig, fetch::HttpClientConfig,
}; };
pub fn build_testing_mikan_client(base_mikan_url: impl IntoUrl) -> eyre::Result<MikanClient> { pub async fn build_testing_mikan_client(base_mikan_url: impl IntoUrl) -> RResult<MikanClient> {
let mikan_client = MikanClient::new(AppMikanConfig { let mikan_client = MikanClient::from_config(MikanConfig {
http_client: HttpClientConfig { http_client: HttpClientConfig {
..Default::default() ..Default::default()
}, },
base_url: base_mikan_url.into_url()?, base_url: base_mikan_url.into_url()?,
})?; })
.await?;
Ok(mikan_client) Ok(mikan_client)
} }

View File

@ -0,0 +1,30 @@
use serde::{Deserialize, Serialize};
use super::middleware::MiddlewareConfig;
/// HTTP server configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WebServerConfig {
    /// The address on which the server should listen on for incoming
    /// connections; defaults to loopback (`127.0.0.1`).
    #[serde(default = "default_binding")]
    pub binding: String,
    /// The port on which the server should listen for incoming connections;
    /// defaults to `5001`.
    #[serde(default = "default_port")]
    pub port: u16,
    /// The webserver host
    pub host: String,
    /// Identify via the `Server` header
    pub ident: Option<String>,
    /// Middleware configurations for the server, including payload limits,
    /// logging, and error handling.
    #[serde(default)]
    pub middlewares: MiddlewareConfig,
}
/// Default bind address for the HTTP server (loopback only).
pub fn default_binding() -> String {
    String::from("127.0.0.1")
}
/// Default HTTP listen port.
pub fn default_port() -> u16 {
    5001
}

View File

@ -2,10 +2,10 @@ import { type Fetcher, createGraphiQLFetcher } from '@graphiql/toolkit';
import { createFileRoute } from '@tanstack/react-router'; import { createFileRoute } from '@tanstack/react-router';
import GraphiQL from 'graphiql'; import GraphiQL from 'graphiql';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { beforeLoadGuard } from '../../auth/guard'; import { beforeLoadGuard } from '../../../auth/guard';
import 'graphiql/graphiql.css'; import 'graphiql/graphiql.css';
import { firstValueFrom } from 'rxjs'; import { firstValueFrom } from 'rxjs';
import { useAuth } from '../../auth/hooks'; import { useAuth } from '../../../auth/hooks';
export const Route = createFileRoute('/graphql/')({ export const Route = createFileRoute('/graphql/')({
component: RouteComponent, component: RouteComponent,

View File

@ -0,0 +1,5 @@
pub mod core;
pub mod graphql;
pub mod oidc;
pub use core::{Controller, ControllerTrait, PrefixController};

View File

@ -1,6 +1,6 @@
import { createFileRoute, redirect } from '@tanstack/react-router'; import { createFileRoute, redirect } from '@tanstack/react-router';
import { EventTypes } from 'oidc-client-rx'; import { EventTypes } from 'oidc-client-rx';
import { useAuth } from '../../auth/hooks'; import { useAuth } from '../../../auth/hooks';
export const Route = createFileRoute('/oidc/callback')({ export const Route = createFileRoute('/oidc/callback')({
component: RouteComponent, component: RouteComponent,

View File

@ -0,0 +1,58 @@
//! Catch Panic Middleware for Axum
//!
//! This middleware catches panics that occur during request handling in the
//! application. When a panic occurs, it logs the error and returns an
//! internal server error response. This middleware helps ensure that the
//! application can gracefully handle unexpected errors without crashing the
//! server.
use std::sync::Arc;
use axum::{Router, response::IntoResponse};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use tower_http::catch_panic::CatchPanicLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
/// Configuration for the panic-catching middleware.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CatchPanic {
    /// Enable converting handler panics into `500` responses.
    #[serde(default)]
    pub enable: bool,
}
/// Handler function for the [`CatchPanicLayer`] middleware.
///
/// This function processes panics by extracting error messages, logging them,
/// and returning an internal server error response.
#[allow(clippy::needless_pass_by_value)]
fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
let err = err.downcast_ref::<String>().map_or_else(
|| err.downcast_ref::<&str>().map_or("no error details", |s| s),
|s| s.as_str(),
);
tracing::error!(err.msg = err, "server_panic");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
impl MiddlewareLayer for CatchPanic {
    /// Returns the name of the middleware
    fn name(&self) -> &'static str {
        "catch_panic"
    }
    /// Returns whether the middleware is enabled or not
    fn is_enabled(&self) -> bool {
        self.enable
    }
    /// Serializes this middleware's configuration for introspection/logging.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
    /// Applies the Catch Panic middleware layer to the Axum router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        Ok(app.layer(CatchPanicLayer::custom(handle_panic)))
    }
}

View File

@ -0,0 +1,41 @@
//! Compression Middleware for Axum
//!
//! This middleware applies compression to HTTP responses to reduce the size of
//! the data being transmitted. This can improve performance by decreasing load
//! times and reducing bandwidth usage. The middleware configuration allows for
//! enabling or disabling compression based on the application settings.
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use tower_http::compression::CompressionLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
/// Configuration for the response-compression middleware.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Compression {
    /// Enable response compression.
    #[serde(default)]
    pub enable: bool,
}
impl MiddlewareLayer for Compression {
    /// Returns the name of the middleware
    fn name(&self) -> &'static str {
        "compression"
    }
    /// Returns whether the middleware is enabled or not
    fn is_enabled(&self) -> bool {
        self.enable
    }
    /// Serializes this middleware's configuration for introspection/logging.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
    /// Applies the Compression middleware layer to the Axum router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        Ok(app.layer(CompressionLayer::new()))
    }
}

View File

@ -0,0 +1,163 @@
//! Configurable and Flexible CORS Middleware
//!
//! This middleware enables Cross-Origin Resource Sharing (CORS) by allowing
//! configurable origins, methods, and headers in HTTP requests. It can be
//! tailored to fit various application requirements, supporting permissive CORS
//! or specific rules as defined in the middleware configuration.
use std::{sync::Arc, time::Duration};
use axum::Router;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower_http::cors::{self, Any};
use crate::{
    app::AppContext,
    errors::{RError, RResult},
    web::middleware::MiddlewareLayer,
};
/// CORS middleware configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Cors {
    /// Enable the CORS middleware.
    #[serde(default)]
    pub enable: bool,
    /// Allow origins; the default single entry `"*"` means "any origin".
    #[serde(default = "default_allow_origins")]
    pub allow_origins: Vec<String>,
    /// Allow headers; the default single entry `"*"` means "any header".
    #[serde(default = "default_allow_headers")]
    pub allow_headers: Vec<String>,
    /// Allow methods; the default single entry `"*"` means "any method".
    #[serde(default = "default_allow_methods")]
    pub allow_methods: Vec<String>,
    /// Allow credentials
    #[serde(default)]
    pub allow_credentials: bool,
    /// Max age in seconds for the `Access-Control-Max-Age` header.
    pub max_age: Option<u64>,
    // Vary headers
    #[serde(default = "default_vary_headers")]
    pub vary: Vec<String>,
}
/// Wildcard default: allow any origin.
fn default_allow_origins() -> Vec<String> {
    vec![String::from("*")]
}
/// Wildcard default: allow any header.
fn default_allow_headers() -> Vec<String> {
    vec![String::from("*")]
}
/// Wildcard default: allow any method.
fn default_allow_methods() -> Vec<String> {
    vec![String::from("*")]
}
/// Headers the default configuration varies responses on.
fn default_vary_headers() -> Vec<String> {
    [
        "origin",
        "access-control-request-method",
        "access-control-request-headers",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
impl Default for Cors {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
impl Cors {
/// Creates cors layer
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// - If any of the provided origins in `allow_origins` cannot be parsed as
/// a valid URI, the function will return a parsing error.
/// - If any of the provided headers in `allow_headers` cannot be parsed as
/// valid HTTP headers, the function will return a parsing error.
/// - If any of the provided methods in `allow_methods` cannot be parsed as
/// valid HTTP methods, the function will return a parsing error.
///
/// In all of these cases, the error returned will be the result of the
/// `parse` method of the corresponding type.
pub fn cors(&self) -> RResult<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::new();
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Acces
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
if self.allow_origins == default_allow_origins() {
cors = cors.allow_origin(Any);
} else {
let mut list = vec![];
for origin in &self.allow_origins {
list.push(origin.parse()?);
}
if !list.is_empty() {
cors = cors.allow_origin(list);
}
}
if self.allow_headers == default_allow_headers() {
cors = cors.allow_headers(Any);
} else {
let mut list = vec![];
for header in &self.allow_headers {
list.push(header.parse()?);
}
if !list.is_empty() {
cors = cors.allow_headers(list);
}
}
if self.allow_methods == default_allow_methods() {
cors = cors.allow_methods(Any);
} else {
let mut list = vec![];
for method in &self.allow_methods {
list.push(method.parse()?);
}
if !list.is_empty() {
cors = cors.allow_methods(list);
}
}
let mut list = vec![];
for v in &self.vary {
list.push(v.parse()?);
}
if !list.is_empty() {
cors = cors.vary(list);
}
if let Some(max_age) = self.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}
cors = cors.allow_credentials(self.allow_credentials);
Ok(cors)
}
}
impl MiddlewareLayer for Cors {
    /// Returns the name of the middleware
    fn name(&self) -> &'static str {
        "cors"
    }
    /// Returns whether the middleware is enabled or not
    fn is_enabled(&self) -> bool {
        self.enable
    }
    /// Serializes this middleware's configuration for introspection/logging.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
    /// Applies the CORS middleware layer to the Axum router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        Ok(app.layer(self.cors()?))
    }
}

View File

@ -0,0 +1,111 @@
//! `ETag` Middleware for Caching Requests
//!
//! This middleware implements the [ETag](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag)
//! HTTP header for caching responses in Axum. `ETags` are used to validate
//! cache entries by comparing a client's stored `ETag` with the one generated
//! by the server. If the `ETags` match, a `304 Not Modified` response is sent,
//! avoiding the need to resend the full content.
use std::{
sync::Arc,
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
extract::Request,
http::{
StatusCode,
header::{ETAG, IF_NONE_MATCH},
},
response::Response,
};
use futures_util::future::BoxFuture;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
/// Configuration for the `ETag` revalidation middleware.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Etag {
    /// Enable `If-None-Match` / `ETag` based `304` short-circuiting.
    #[serde(default)]
    pub enable: bool,
}
impl MiddlewareLayer for Etag {
    /// Returns the name of the middleware
    fn name(&self) -> &'static str {
        "etag"
    }
    /// Returns whether the middleware is enabled or not
    fn is_enabled(&self) -> bool {
        self.enable
    }
    /// Serializes this middleware's configuration for introspection/logging.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
    /// Applies the `ETag` middleware to the application router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        Ok(app.layer(EtagLayer))
    }
}
/// [`EtagLayer`] struct for adding `ETag` functionality as a Tower service
/// layer.
#[derive(Default, Clone)]
struct EtagLayer;
impl<S> Layer<S> for EtagLayer {
    type Service = EtagMiddleware<S>;
    // Stateless layer: simply wraps the inner service.
    fn layer(&self, inner: S) -> Self::Service {
        EtagMiddleware { inner }
    }
}
/// Tower service that answers `304 Not Modified` when the request's
/// `If-None-Match` matches the response's `ETag`.
#[derive(Clone)]
struct EtagMiddleware<S> {
    inner: S,
}
impl<S> Service<Request<Body>> for EtagMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the inner service, then compares the response `ETag` against the
    /// request's `If-None-Match` per RFC 9110 §13.1.2: weak comparison
    /// (ignoring a `W/` prefix), supporting `*` and comma-separated tag
    /// lists. Previously only exact byte equality of single values matched.
    fn call(&mut self, request: Request) -> Self::Future {
        let if_none_match = request.headers().get(IF_NONE_MATCH).cloned();
        let future = self.inner.call(request);
        let res_fut = async move {
            let response = future.await?;
            let (Some(req_etags), Some(resp_etag)) =
                (if_none_match, response.headers().get(ETAG).cloned())
            else {
                return Ok(response);
            };
            // Weak comparison: strip an optional `W/` validator prefix.
            fn opaque(tag: &str) -> &str {
                let tag = tag.trim();
                tag.strip_prefix("W/").unwrap_or(tag)
            }
            let matched = match (req_etags.to_str(), resp_etag.to_str()) {
                (Ok(req), Ok(resp)) => {
                    req.trim() == "*"
                        || req.split(',').any(|tag| opaque(tag) == opaque(resp))
                }
                // Non-UTF-8 header values: fall back to exact byte equality,
                // which is what the previous implementation always did.
                _ => req_etags == resp_etag,
            };
            if !matched {
                return Ok(response);
            }
            let mut not_modified = Response::builder()
                .status(StatusCode::NOT_MODIFIED)
                .body(Body::empty())
                .unwrap();
            // RFC 9110 §15.4.5: a 304 SHOULD repeat the ETag of the selected
            // representation so caches can update their stored entry.
            not_modified.headers_mut().insert(ETAG, resp_etag);
            Ok(not_modified)
        };
        Box::pin(res_fut)
    }
}

View File

@ -0,0 +1,71 @@
//! Detect a content type and format and responds accordingly
use axum::{
extract::FromRequestParts,
http::{
header::{ACCEPT, CONTENT_TYPE, HeaderMap},
request::Parts,
},
};
use serde::{Deserialize, Serialize};
use crate::errors::RError as Error;
/// Extractor newtype wrapping the negotiated [`RespondTo`] format.
#[derive(Debug, Deserialize, Serialize)]
pub struct Format(pub RespondTo);
/// Response format detected from the request's `Content-Type`/`Accept`.
#[derive(Debug, Deserialize, Serialize)]
pub enum RespondTo {
    /// Neither header was present (or readable as a string).
    None,
    Html,
    Json,
    Xml,
    /// Any other MIME type, carried verbatim.
    Other(String),
}
/// Maps a MIME type string (prefix match) to the [`RespondTo`] variant.
fn detect_format(content_type: &str) -> RespondTo {
    if content_type.starts_with("text/html") {
        return RespondTo::Html;
    }
    if content_type.starts_with("application/json") {
        return RespondTo::Json;
    }
    let xml_prefixes = ["text/xml", "application/xml", "application/xhtml"];
    if xml_prefixes.iter().any(|p| content_type.starts_with(p)) {
        return RespondTo::Xml;
    }
    RespondTo::Other(content_type.to_string())
}
pub fn get_respond_to(headers: &HeaderMap) -> RespondTo {
#[allow(clippy::option_if_let_else)]
if let Some(content_type) = headers.get(CONTENT_TYPE).and_then(|h| h.to_str().ok()) {
detect_format(content_type)
} else if let Some(content_type) = headers.get(ACCEPT).and_then(|h| h.to_str().ok()) {
detect_format(content_type)
} else {
RespondTo::None
}
}
/// Axum extractor: wraps the negotiated format in [`Format`].
impl<S> FromRequestParts<S> for Format
where
    S: Send + Sync,
{
    type Rejection = Error;

    // Infallible in practice: header inspection itself cannot fail.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
        Ok(Self(get_respond_to(&parts.headers)))
    }
}
/// Axum extractor: yields the negotiated [`RespondTo`] directly.
impl<S> FromRequestParts<S> for RespondTo
where
    S: Send + Sync,
{
    type Rejection = Error;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Error> {
        Ok(get_respond_to(&parts.headers))
    }
}

View File

@ -0,0 +1,102 @@
//! Logger Middleware
//!
//! This middleware provides logging functionality for HTTP requests. It uses
//! `TraceLayer` to log detailed information about each request, such as the
//! HTTP method, URI, version, user agent, and an associated request ID.
//! Additionally, it integrates the application's runtime environment
//! into the log context, allowing environment-specific logging (e.g.,
//! "development", "production").
use std::sync::Arc;
use axum::{Router, http};
use serde::{Deserialize, Serialize};
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
use crate::{
app::{AppContext, Environment},
errors::RResult,
web::middleware::{MiddlewareLayer, request_id::LocoRequestId},
};
/// Logger middleware configuration: a single on/off switch.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    // Defaults to `false` when absent from configuration.
    #[serde(default)]
    pub enable: bool,
}
/// [`Middleware`] struct responsible for logging HTTP requests.
#[derive(Serialize, Debug)]
pub struct Middleware {
    config: Config,
    // Captured at construction so spans can record the runtime environment.
    environment: Environment,
}
/// Builds the logger [`Middleware`] from configuration, capturing the
/// application's runtime environment for inclusion in log spans.
#[must_use]
pub fn new(config: &Config, context: Arc<AppContext>) -> Middleware {
    let environment = context.environment.clone();
    Middleware {
        config: config.clone(),
        environment,
    }
}
impl MiddlewareLayer for Middleware {
    /// Returns the name of the middleware
    fn name(&self) -> &'static str {
        "logger"
    }
    /// Returns whether the middleware is enabled or not
    fn is_enabled(&self) -> bool {
        self.config.enable
    }
    /// Serializes this middleware's configuration for introspection.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
    /// Applies the logger middleware to the application router by adding layers
    /// for:
    ///
    /// - `TraceLayer`: Logs detailed information about each HTTP request.
    /// - `AddExtensionLayer`: Adds the current environment to the request
    ///   extensions, making it accessible to the `TraceLayer` for logging.
    ///
    /// The `TraceLayer` is customized with `make_span_with` to extract
    /// request-specific details like method, URI, version, user agent, and
    /// request ID, then create a tracing span for the request.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        Ok(app
            .layer(
                TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
                    let ext = request.extensions();
                    // Injected by the request_id middleware; fall back to a
                    // sentinel when that middleware is not installed.
                    let request_id = ext
                        .get::<LocoRequestId>()
                        .map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string());
                    let user_agent = request
                        .headers()
                        .get(axum::http::header::USER_AGENT)
                        .map_or("", |h| h.to_str().unwrap_or(""));
                    let env: String = request
                        .extensions()
                        .get::<Environment>()
                        .map(|e| e.full_name().to_string())
                        .unwrap_or_default();
                    // NOTE(review): span is created at ERROR level, presumably
                    // so it survives restrictive level filters — confirm intent.
                    tracing::error_span!(
                        "http-request",
                        "http.method" = tracing::field::display(request.method()),
                        "http.uri" = tracing::field::display(request.uri()),
                        "http.version" = tracing::field::debug(request.version()),
                        "http.user_agent" = tracing::field::display(user_agent),
                        "environment" = tracing::field::display(env),
                        request_id = tracing::field::display(request_id),
                    )
                }),
            )
            .layer(AddExtensionLayer::new(self.environment.clone())))
    }
}

View File

@ -0,0 +1,165 @@
pub mod catch_panic;
pub mod compression;
pub mod cors;
pub mod etag;
pub mod format;
pub mod logger;
pub mod remote_ip;
pub mod request_id;
pub mod secure_headers;
pub mod timeout;
use std::sync::Arc;
use axum::Router;
use serde::{Deserialize, Serialize};
use crate::{app::AppContext, errors::RResult};
/// Trait representing the behavior of middleware components in the application.
/// When implementing a new middleware, make sure to go over this checklist:
/// * The name of the middleware should be an ID that is similar to the field
///   name in configuration (look at how `serde` calls it)
/// * Default value implementation should be paired with `serde` default
///   handlers and default serialization implementation. Which means deriving
///   `Default` will _not_ work. You can use `serde_json` and serialize a new
///   config from an empty value, which will cause `serde` default value
///   handlers to kick in.
/// * If you need completely blank values for configuration (for example for
///   testing), implement an `::empty() -> Self` call ad-hoc.
pub trait MiddlewareLayer {
    /// Returns the name of the middleware.
    /// This should match the name of the property in the containing
    /// `middleware` section in configuration (as named by `serde`)
    fn name(&self) -> &'static str;
    /// Returns whether the middleware is enabled or not.
    /// If the middleware is switchable, take this value from a configuration
    /// value
    fn is_enabled(&self) -> bool {
        true
    }
    /// Returns middleware config.
    ///
    /// # Errors
    /// when could not convert middleware to [`serde_json::Value`]
    fn config(&self) -> serde_json::Result<serde_json::Value>;
    /// Applies the middleware to the given Axum router and returns the modified
    /// router.
    ///
    /// # Errors
    ///
    /// If there is an issue when adding the middleware to the router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>>;
}
/// Builds the default, ordered middleware stack from server configuration.
///
/// Each entry falls back to a hard-coded default when its configuration
/// section is absent; each middleware's own `is_enabled` decides whether it
/// is ultimately applied to the router.
#[allow(clippy::unnecessary_lazy_evaluations)]
#[must_use]
pub fn default_middleware_stack(ctx: Arc<AppContext>) -> Vec<Box<dyn MiddlewareLayer>> {
    // Shortened reference to middlewares
    let middlewares = &ctx.config.server.middlewares;
    vec![
        // CORS middleware with a default if none (disabled by default)
        Box::new(middlewares.cors.clone().unwrap_or_else(|| cors::Cors {
            enable: false,
            ..Default::default()
        })),
        // Catch Panic middleware with a default if none (enabled by default)
        Box::new(
            middlewares
                .catch_panic
                .clone()
                .unwrap_or_else(|| catch_panic::CatchPanic { enable: true }),
        ),
        // Etag middleware with a default if none (enabled by default)
        Box::new(
            middlewares
                .etag
                .clone()
                .unwrap_or_else(|| etag::Etag { enable: true }),
        ),
        // Remote IP middleware with a default if none (disabled by default)
        Box::new(
            middlewares
                .remote_ip
                .clone()
                .unwrap_or_else(|| remote_ip::RemoteIpMiddleware {
                    enable: false,
                    ..Default::default()
                }),
        ),
        // Compression middleware with a default if none (disabled by default)
        Box::new(
            middlewares
                .compression
                .clone()
                .unwrap_or_else(|| compression::Compression { enable: false }),
        ),
        // Timeout Request middleware with a default if none (disabled)
        Box::new(
            middlewares
                .timeout_request
                .clone()
                .unwrap_or_else(|| timeout::TimeOut {
                    enable: false,
                    ..Default::default()
                }),
        ),
        // Secure Headers middleware with a default if none (disabled)
        Box::new(middlewares.secure_headers.clone().unwrap_or_else(|| {
            secure_headers::SecureHeader {
                enable: false,
                ..Default::default()
            }
        })),
        // Logger middleware with default logger configuration (enabled)
        Box::new(logger::new(
            &middlewares
                .logger
                .clone()
                .unwrap_or_else(|| logger::Config { enable: true }),
            ctx.clone(),
        )),
        // Request ID middleware with a default if none (enabled by default)
        Box::new(
            middlewares
                .request_id
                .clone()
                .unwrap_or_else(|| request_id::RequestId { enable: true }),
        ),
    ]
}
/// Server middleware configuration structure.
///
/// Every section is optional; absent sections get the defaults chosen in
/// [`default_middleware_stack`].
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
pub struct MiddlewareConfig {
    /// Compression for the response.
    pub compression: Option<compression::Compression>,
    /// Etag cache headers.
    pub etag: Option<etag::Etag>,
    /// Logger and augmenting trace id with request data
    pub logger: Option<logger::Config>,
    /// Catch any code panic and log the error.
    pub catch_panic: Option<catch_panic::CatchPanic>,
    /// Setting a global timeout for requests
    pub timeout_request: Option<timeout::TimeOut>,
    /// CORS configuration
    pub cors: Option<cors::Cors>,
    /// Sets a set of secure headers
    pub secure_headers: Option<secure_headers::SecureHeader>,
    /// Calculates a remote IP based on `X-Forwarded-For` when behind a proxy
    pub remote_ip: Option<remote_ip::RemoteIpMiddleware>,
    /// Request ID
    pub request_id: Option<request_id::RequestId>,
}

View File

@ -0,0 +1,306 @@
//! Remote IP Middleware for inferring the client's IP address based on the
//! `X-Forwarded-For` header.
//!
//! This middleware is useful when running behind proxies or load balancers that
//! add the `X-Forwarded-For` header, which includes the original client IP
//! address.
//!
//! The middleware provides a mechanism to configure trusted proxies and extract
//! the most likely client IP from the `X-Forwarded-For` header, skipping any
//! trusted proxy IPs.
use std::{
fmt,
iter::Iterator,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::{Arc, OnceLock},
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
extract::{ConnectInfo, FromRequestParts, Request},
http::{header::HeaderMap, request::Parts},
response::Response,
};
use futures_util::future::BoxFuture;
use ipnetwork::IpNetwork;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use tracing::error;
use crate::{
app::AppContext,
errors::{RError, RResult},
web::middleware::MiddlewareLayer,
};
// Built-in trusted proxy ranges, parsed once on first use.
static LOCAL_TRUSTED_PROXIES: OnceLock<Vec<IpNetwork>> = OnceLock::new();
/// Returns the built-in list of local/private network ranges treated as
/// trusted proxies when no explicit list is configured.
fn get_local_trusted_proxies() -> &'static Vec<IpNetwork> {
    LOCAL_TRUSTED_PROXIES.get_or_init(|| {
        [
            "127.0.0.0/8", // localhost IPv4 range, per RFC-3330
            "::1", // localhost IPv6
            "fc00::/7", // private IPv6 range fc00::/7
            "10.0.0.0/8", // private IPv4 range 10.x.x.x
            "172.16.0.0/12", // private IPv4 range 172.16.0.0 .. 172.31.255.255
            "192.168.0.0/16", // private IPv4 range 192.168.x.x
        ]
        .iter()
        // unwrap cannot fail: the entries above are valid literals.
        .map(|ip| IpNetwork::from_str(ip).unwrap())
        .collect()
    })
}
/// Header consulted for proxy-forwarded client addresses.
const X_FORWARDED_FOR: &str = "X-Forwarded-For";
///
/// Performs a remote ip "calculation", inferring the most likely
/// client IP from the `X-Forwarded-For` header that is used by
/// load balancers and proxies.
///
/// WARNING
/// =======
///
/// LIKE ANY SUCH REMOTE IP MIDDLEWARE, IN THE WRONG ARCHITECTURE IT CAN MAKE
/// YOU VULNERABLE TO IP SPOOFING.
///
/// This middleware assumes that there is at least one proxy sitting around and
/// setting headers with the client's remote IP address. Otherwise any client
/// can claim to have any IP address by setting the `X-Forwarded-For` header.
///
/// DO NOT USE THIS MIDDLEWARE IF YOU DONT KNOW THAT YOU NEED IT
///
/// -- But if you need it, it's crucial to use it (since it's the only way to
/// get the original client IP)
///
/// This middleware is mostly implemented after the Rails `remote_ip`
/// middleware, and looking at other production Rust services with Axum, taking
/// the best of both worlds to balance performance and pragmatism.
///
/// Similarities to the Rails `remote_ip` middleware:
///
/// * Uses `X-Forwarded-For`
/// * Uses the same built-in trusted proxies list
/// * You can provide a list of `trusted_proxies` which will **replace** the
///   built-in trusted proxies
///
/// Differences from the Rails `remote_ip` middleware:
///
/// * You get an indication if the remote IP is actually resolved or is the
///   socket IP (no `X-Forwarded-For` header or could not parse)
/// * We do not not use the `Client-IP` header, or try to detect "spoofing"
///   (spoofing while doing remote IP resolution is virtually non-detectable)
/// * Order of filtering IPs from `X-Forwarded-For` is done according to [the de
///   facto spec](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#selecting_an_ip_address)
///   "Trusted proxy list"
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct RemoteIpMiddleware {
    // Defaults to `false` when absent from configuration.
    #[serde(default)]
    pub enable: bool,
    /// A list of alternative proxy list IP ranges and/or network range (will
    /// replace built-in proxy list)
    pub trusted_proxies: Option<Vec<String>>,
}
impl MiddlewareLayer for RemoteIpMiddleware {
    /// Identifier matching the `remote_ip` configuration key.
    fn name(&self) -> &'static str {
        "remote_ip"
    }

    /// Enabled when switched on and the trusted-proxy list, if supplied, is
    /// non-empty (an explicit empty list would leave nothing to trust).
    fn is_enabled(&self) -> bool {
        self.enable
            && self
                .trusted_proxies
                .as_ref()
                .map_or(true, |proxies| !proxies.is_empty())
    }

    /// Serializes this middleware's configuration for introspection.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Applies the Remote IP middleware to the given Axum router.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        let layer = RemoteIPLayer::new(self)?;
        Ok(app.layer(layer))
    }
}
// implementation reference: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
fn maybe_get_forwarded(
headers: &HeaderMap,
trusted_proxies: Option<&Vec<IpNetwork>>,
) -> Option<IpAddr> {
/*
> There may be multiple X-Forwarded-For headers present in a request. The IP addresses in these headers must be treated as a single list,
> starting with the first IP address of the first header and continuing to the last IP address of the last header.
> There are two ways of making this single list:
> join the X-Forwarded-For full header values with commas and then split by comma into a list, or
> split each X-Forwarded-For header by comma into lists and then join the lists
*/
let xffs = headers
.get_all(X_FORWARDED_FOR)
.iter()
.map(|hdr| hdr.to_str())
.filter_map(Result::ok)
.collect::<Vec<_>>();
if xffs.is_empty() {
return None;
}
let forwarded = xffs.join(",");
forwarded
.split(',')
.map(str::trim)
.map(str::parse)
.filter_map(Result::ok)
/*
> Trusted proxy list: The IPs or IP ranges of the trusted reverse proxies are configured.
> The X-Forwarded-For IP list is searched from the rightmost, skipping all addresses that
> are on the trusted proxy list. The first non-matching address is the target address.
*/
.filter(|ip| {
// trusted proxies provided REPLACES our default local proxies
let proxies = trusted_proxies.unwrap_or_else(|| get_local_trusted_proxies());
!proxies
.iter()
.any(|trusted_proxy| trusted_proxy.contains(*ip))
})
/*
> When choosing the X-Forwarded-For client IP address closest to the client (untrustworthy
> and not for security-related purposes), the first IP from the leftmost that is a valid
> address and not private/internal should be selected.
>
NOTE:
> The first trustworthy X-Forwarded-For IP address may belong to an untrusted intermediate
> proxy rather than the actual client computer, but it is the only IP suitable for security uses.
*/
.next_back()
}
/// The resolved remote IP for a request, retrievable as an axum extractor.
#[derive(Copy, Clone, Debug)]
pub enum RemoteIP {
    /// Derived from `X-Forwarded-For` after skipping trusted proxies.
    Forwarded(IpAddr),
    /// Fallback: the peer socket address (no usable forwarded header).
    Socket(IpAddr),
    /// Neither a forwarded header nor socket info was available.
    None,
}
impl<S> FromRequestParts<S> for RemoteIP
where
    S: Send + Sync,
{
    type Rejection = ();
    /// Reads the value stored by the remote-IP middleware; yields
    /// `RemoteIP::None` when the middleware is not installed. Infallible.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let ip = parts.extensions.get::<Self>();
        Ok(*ip.unwrap_or(&Self::None))
    }
}
impl fmt::Display for RemoteIP {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Forwarded(ip) => write!(f, "remote: {ip}"),
            Self::Socket(ip) => write!(f, "socket: {ip}"),
            Self::None => write!(f, "--"),
        }
    }
}
/// Tower layer that resolves the remote IP for each request.
#[derive(Clone, Debug)]
struct RemoteIPLayer {
    // `None` means the built-in local/private ranges are used.
    trusted_proxies: Option<Vec<IpNetwork>>,
}
impl RemoteIPLayer {
    /// Returns a new remote-IP layer.
    ///
    /// # Errors
    /// Fails if any configured trusted proxy cannot be parsed as an IP
    /// address or network range.
    pub fn new(config: &RemoteIpMiddleware) -> RResult<Self> {
        Ok(Self {
            trusted_proxies: config
                .trusted_proxies
                .as_ref()
                .map(|proxies| {
                    proxies
                        .iter()
                        .map(|proxy| {
                            IpNetwork::from_str(proxy).map_err(|err| {
                                // fixed typo: "middleare" -> "middleware"
                                RError::CustomMessageString(format!(
                                    "remote ip middleware cannot parse trusted proxy \
                                     configuration: `{proxy}`, reason: `{err}`",
                                ))
                            })
                        })
                        .collect::<RResult<Vec<_>>>()
                })
                .transpose()?,
        })
    }
}
impl<S> Layer<S> for RemoteIPLayer {
    type Service = RemoteIPMiddleware<S>;
    /// Wraps the inner service with the remote-IP resolution middleware.
    fn layer(&self, inner: S) -> Self::Service {
        RemoteIPMiddleware {
            inner,
            layer: self.clone(),
        }
    }
}
/// Remote IP Detection Middleware
#[derive(Clone, Debug)]
#[must_use]
pub struct RemoteIPMiddleware<S> {
    inner: S,
    // Carries the parsed trusted-proxy configuration.
    layer: RemoteIPLayer,
}
impl<S> Service<Request<Body>> for RemoteIPMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let layer = self.layer.clone();
let xff_ip = maybe_get_forwarded(req.headers(), layer.trusted_proxies.as_ref());
let remote_ip = xff_ip.map_or_else(
|| {
let ip = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map_or_else(
|| {
error!(
"remote ip middleware cannot get socket IP (not set in axum \
extensions): setting IP to `127.0.0.1`"
);
RemoteIP::None
},
|info| RemoteIP::Socket(info.ip()),
);
ip
},
RemoteIP::Forwarded,
);
req.extensions_mut().insert(remote_ip);
Box::pin(self.inner.call(req))
}
}

View File

@ -0,0 +1,132 @@
//! Middleware to generate or ensure a unique request ID for every request.
//!
//! The request ID is stored in the `x-request-id` header, and it is either
//! generated or sanitized if already present in the request.
//!
//! This can be useful for tracking requests across services, logging, and
//! debugging.
use axum::{Router, extract::Request, http::HeaderValue, middleware::Next, response::Response};
use regex::Regex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{web::middleware::MiddlewareLayer, app::AppContext, errors::RResult};
/// Header carrying the request ID on both requests and responses.
const X_REQUEST_ID: &str = "x-request-id";
/// Maximum length kept from an inbound (sanitized) request ID.
const MAX_LEN: usize = 255;
use std::sync::{Arc, OnceLock};
// Compiled once on first use.
static ID_CLEANUP: OnceLock<Regex> = OnceLock::new();
/// Returns the regex that strips everything except word characters, `-`
/// and `@` from inbound request IDs.
fn get_id_cleanup() -> &'static Regex {
    ID_CLEANUP.get_or_init(|| Regex::new(r"[^\w\-@]").unwrap())
}
/// Request-ID middleware configuration: a single on/off switch.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RequestId {
    // Defaults to `false` when absent from configuration.
    #[serde(default)]
    pub enable: bool,
}
impl MiddlewareLayer for RequestId {
    /// Identifier matching the `request_id` configuration key.
    fn name(&self) -> &'static str {
        "request_id"
    }

    /// Enabled purely by configuration.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Serializes this middleware's configuration for introspection.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Installs [`request_id_middleware`] on the router so every request
    /// carries a sanitized or freshly generated `x-request-id`.
    ///
    /// # Errors
    /// This function returns an error if the middleware cannot be applied.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        let layer = axum::middleware::from_fn(request_id_middleware);
        Ok(app.layer(layer))
    }
}
/// Wrapper struct for storing the request ID in the request's extensions.
#[derive(Debug, Clone)]
pub struct LocoRequestId(String);
impl LocoRequestId {
    /// Retrieves the request ID as a string slice.
    #[must_use]
    pub fn get(&self) -> &str {
        self.0.as_str()
    }
}
/// Middleware function that guarantees every request carries a request ID.
///
/// Reads the inbound `x-request-id` header (sanitizing it through
/// [`make_request_id`], which also generates a fresh UUID when the header is
/// absent or unusable), stores the result in the request extensions as a
/// [`LocoRequestId`], and mirrors it onto the response headers.
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
    let incoming = request.headers().get(X_REQUEST_ID).cloned();
    let request_id = make_request_id(incoming);
    request
        .extensions_mut()
        .insert(LocoRequestId(request_id.clone()));
    let mut response = next.run(request).await;
    match HeaderValue::from_str(request_id.as_str()) {
        Ok(value) => {
            response.headers_mut().insert(X_REQUEST_ID, value);
        }
        Err(_) => {
            tracing::warn!("could not set request ID into response headers: `{request_id}`",);
        }
    }
    response
}
/// Produces a request ID: sanitizes the supplied header value, or generates
/// a fresh UUID when the header is missing, unreadable, or sanitizes to an
/// empty string.
fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
    // see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39
    let sanitized = maybe_request_id.and_then(|hdr| {
        let value = hdr.to_str().ok()?;
        let cleaned: String = get_id_cleanup()
            .replace_all(value, "")
            .chars()
            .take(MAX_LEN)
            .collect();
        if cleaned.is_empty() { None } else { Some(cleaned) }
    });
    sanitized.unwrap_or_else(|| Uuid::new_v4().to_string())
}
#[cfg(test)]
mod tests {
    use axum::http::HeaderValue;
    use insta::assert_debug_snapshot;
    use super::make_request_id;
    // Covers sanitization, empty/garbage inputs, truncation to MAX_LEN, and
    // UUID generation when no header is supplied.
    #[test]
    fn create_or_fetch_request_id() {
        let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
        assert_debug_snapshot!(id);
        let id = make_request_id(Some(HeaderValue::from_static("")));
        assert_debug_snapshot!(id.len());
        let id = make_request_id(Some(HeaderValue::from_static("==========")));
        assert_debug_snapshot!(id.len());
        let long_id = "x".repeat(1000);
        let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
        assert_debug_snapshot!(id.len());
        let id = make_request_id(None);
        assert_debug_snapshot!(id.len());
    }
}

View File

@ -0,0 +1,26 @@
{
"empty":{},
"github":{
"Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src https:; style-src 'self' https: 'unsafe-inline'",
"Strict-Transport-Security": "max-age=631138519",
"X-Content-Type-Options": "nosniff",
"X-Download-Options": "noopen",
"X-Frame-Options": "sameorigin",
"X-Permitted-Cross-Domain-Policies": "none",
"X-Xss-Protection": "0"
},
"owasp":{
"Cache-Control": "no-store, max-age=0",
"Clear-Site-Data": "\"cache\",\"cookies\",\"storage\"",
"Content-Security-Policy": "default-src 'self'; form-action 'self'; object-src 'none'; frame-ancestors 'none'; upgrade-insecure-requests; block-all-mixed-content",
"Cross-Origin-Embedder-Policy": "require-corp",
"Cross-Origin-Opener-Policy": "same-origin",
"Cross-Origin-Resource-Policy": "same-origin",
"Permissions-Policy": "accelerometer=(), autoplay=(), camera=(), cross-origin-isolated=(), display-capture=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(self), usb=(), web-share=(), xr-spatial-tracking=(), clipboard-read=(), clipboard-write=(), gamepad=(), hid=(), idle-detection=(), interest-cohort=(), serial=(), unload=()",
"Referrer-Policy": "no-referrer",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "deny",
"X-Permitted-Cross-Domain-Policies": "none"
}
}

View File

@ -0,0 +1,311 @@
//! Sets secure headers for your backend to promote security-by-default.
//!
//! This middleware applies secure HTTP headers, providing pre-defined presets
//! (e.g., "github") and the ability to override or define custom headers.
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, OnceLock},
task::{Context, Poll},
};
use axum::{
Router,
body::Body,
http::{HeaderName, HeaderValue, Request},
response::Response,
};
use futures_util::future::BoxFuture;
use serde::{Deserialize, Serialize};
use serde_json::{self, json};
use tower::{Layer, Service};
use crate::{
app::AppContext,
web::middleware::MiddlewareLayer,
errors::{RError, RResult},
};
// Preset name -> ordered header map, loaded once from the bundled JSON.
static PRESETS: OnceLock<HashMap<String, BTreeMap<String, String>>> = OnceLock::new();
/// Lazily parses the bundled `secure_headers.json` preset definitions.
fn get_presets() -> &'static HashMap<String, BTreeMap<String, String>> {
    PRESETS.get_or_init(|| {
        let json_data = include_str!("secure_headers.json");
        // unwrap: the JSON is compiled into the binary and known-valid.
        serde_json::from_str(json_data).unwrap()
    })
}
/// Sets a predefined or custom set of secure headers.
///
/// We recommend our `github` preset. Presets values are derived
/// from the [secure_headers](https://github.com/github/secure_headers) Ruby
/// library which Github (and originally Twitter) use.
///
/// To use a preset, in your `config/development.yaml`:
///
/// ```yaml
/// middlewares:
/// secure_headers:
/// preset: github
/// ```
///
/// You can also override individual headers on a given preset:
///
/// ```yaml
/// middlewares:
/// secure_headers:
/// preset: github
/// overrides:
/// foo: bar
/// ```
///
/// Or start from scratch:
///
///```yaml
/// middlewares:
/// secure_headers:
/// preset: empty
/// overrides:
/// one: two
/// ```
///
/// To support `htmx`, You can add the following override, to allow some inline
/// running of scripts:
///
/// ```yaml
/// secure_headers:
/// preset: github
/// overrides:
/// # this allows you to use HTMX, and has unsafe-inline. Remove or consider in production
/// "Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src 'unsafe-inline' 'self' https:; style-src 'self' https: 'unsafe-inline'"
/// ```
///
/// For the list of presets and their content look at [secure_headers.json](https://github.com/loco-rs/loco/blob/master/src/controller/middleware/secure_headers.rs)
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SecureHeader {
    // Defaults to `false` when absent from configuration.
    #[serde(default)]
    pub enable: bool,
    // Preset name looked up in `secure_headers.json`; defaults to "github".
    #[serde(default = "default_preset")]
    pub preset: String,
    // Optional per-header overrides applied after the preset.
    #[serde(default)]
    pub overrides: Option<BTreeMap<String, String>>,
}
impl Default for SecureHeader {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
/// Name of the preset used when configuration does not specify one.
fn default_preset() -> String {
    String::from("github")
}
impl MiddlewareLayer for SecureHeader {
    /// Identifier matching the `secure_headers` configuration key.
    fn name(&self) -> &'static str {
        "secure_headers"
    }

    /// Enabled purely by configuration.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Serializes this middleware's configuration for introspection.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Applies the secure headers layer to the application router
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        let layer = SecureHeaders::new(self)?;
        Ok(app.layer(layer))
    }
}
impl SecureHeader {
    /// Converts the configuration into a list of `(name, value)` header
    /// pairs: the selected preset first, then any overrides.
    ///
    /// # Errors
    /// Fails when the preset name is unknown, or when a header name/value
    /// is not a valid HTTP header.
    fn as_headers(&self) -> RResult<Vec<(HeaderName, HeaderValue)>> {
        let mut headers = vec![];
        let preset = &self.preset;
        let p = get_presets().get(preset).ok_or_else(|| {
            RError::CustomMessageString(format!(
                "secure_headers: a preset named `{preset}` does not exist"
            ))
        })?;
        Self::push_headers(&mut headers, p)?;
        if let Some(overrides) = &self.overrides {
            Self::push_headers(&mut headers, overrides)?;
        }
        Ok(headers)
    }
    /// Parses each `(String, String)` pair into a typed header name/value
    /// and appends it to `headers`.
    ///
    /// # Errors
    /// Fails when a key is not a valid header name or a value is not a
    /// valid header value.
    fn push_headers(
        headers: &mut Vec<(HeaderName, HeaderValue)>,
        hm: &BTreeMap<String, String>,
    ) -> RResult<()> {
        for (k, v) in hm {
            // Borrow directly: the previous version cloned both the key and
            // the value only to immediately borrow them again.
            headers.push((
                HeaderName::from_bytes(k.as_bytes())?,
                HeaderValue::from_str(v)?,
            ));
        }
        Ok(())
    }
}
/// The [`SecureHeaders`] layer which wraps around the service and injects
/// security headers
#[derive(Clone, Debug)]
pub struct SecureHeaders {
    // Pre-parsed header list; built once at construction time.
    headers: Vec<(HeaderName, HeaderValue)>,
}
impl SecureHeaders {
    /// Creates a new [`SecureHeaders`] instance with the provided
    /// configuration.
    ///
    /// # Errors
    /// Returns an error if any header values are invalid.
    pub fn new(config: &SecureHeader) -> RResult<Self> {
        Ok(Self {
            headers: config.as_headers()?,
        })
    }
}
impl<S> Layer<S> for SecureHeaders {
    type Service = SecureHeadersMiddleware<S>;
    /// Wraps the provided service with the secure headers middleware.
    fn layer(&self, inner: S) -> Self::Service {
        SecureHeadersMiddleware {
            inner,
            layer: self.clone(),
        }
    }
}
/// The secure headers middleware
#[derive(Clone, Debug)]
#[must_use]
pub struct SecureHeadersMiddleware<S> {
    inner: S,
    // Holds the header list injected into every response.
    layer: SecureHeaders,
}
impl<S> Service<Request<Body>> for SecureHeadersMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards the request to the inner service, then inserts each of the
    /// configured security headers into the response (overwriting any
    /// header of the same name).
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let layer = self.layer.clone();
        let future = self.inner.call(request);
        Box::pin(async move {
            let mut response: Response = future.await?;
            let response_headers = response.headers_mut();
            for (name, value) in &layer.headers {
                response_headers.insert(name, value.clone());
            }
            Ok(response)
        })
    }
}
#[cfg(test)]
mod tests {
    use axum::{
        Router,
        http::{HeaderMap, Method},
        routing::get,
    };
    use insta::assert_debug_snapshot;
    use tower::ServiceExt;
    use super::*;
    /// Collapses a `HeaderMap` into a sorted name -> value map so snapshot
    /// output is deterministic.
    fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
        headers
            .iter()
            .map(|(k, v)| {
                let key = k.to_string();
                let value = v.to_str().unwrap_or("").to_string();
                (key, value)
            })
            .collect()
    }
    // The `github` preset headers are applied to responses.
    #[tokio::test]
    async fn can_set_headers() {
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: None,
        };
        let app = Router::new()
            .route("/", get(|| async {}))
            .layer(SecureHeaders::new(&config).unwrap());
        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }
    // Overrides replace preset values and may add entirely new headers.
    #[tokio::test]
    async fn can_override_headers() {
        let mut overrides = BTreeMap::new();
        overrides.insert("X-Download-Options".to_string(), "foobar".to_string());
        overrides.insert("New-Header".to_string(), "baz".to_string());
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: Some(overrides),
        };
        let app = Router::new()
            .route("/", get(|| async {}))
            .layer(SecureHeaders::new(&config).unwrap());
        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }
    // `SecureHeader::default()` must behave exactly like the github preset.
    #[tokio::test]
    async fn default_is_github_preset() {
        let config = SecureHeader::default();
        let app = Router::new()
            .route("/", get(|| async {}))
            .layer(SecureHeaders::new(&config).unwrap());
        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }
}

View File

@ -0,0 +1,64 @@
//! Timeout Request Middleware.
//!
//! This middleware applies a timeout to requests processed by the application.
//! The timeout duration is configurable and defined via the
//! [`TimeoutRequestMiddleware`] configuration. The middleware ensures that
//! requests do not run beyond the specified timeout period, improving the
//! overall performance and responsiveness of the application.
//!
//! If a request exceeds the specified timeout duration, the middleware will
//! return a `408 Request Timeout` status code to the client, indicating that
//! the request took too long to process.
use std::{sync::Arc, time::Duration};
use axum::Router;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower_http::timeout::TimeoutLayer;
use crate::{app::AppContext, errors::RResult, web::middleware::MiddlewareLayer};
/// Timeout middleware configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TimeOut {
    // Defaults to `false` when absent from configuration.
    #[serde(default)]
    pub enable: bool,
    // Timeout request in milliseconds
    #[serde(default = "default_timeout")]
    pub timeout: u64,
}
impl Default for TimeOut {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
/// Default request timeout in milliseconds (five seconds).
fn default_timeout() -> u64 {
    5000
}
impl MiddlewareLayer for TimeOut {
    /// Identifier matching the `timeout_request` configuration key.
    fn name(&self) -> &'static str {
        "timeout_request"
    }

    /// Enabled purely by configuration.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Serializes this middleware's configuration for introspection.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Wraps the router in a [`TimeoutLayer`] so requests running past the
    /// configured duration are interrupted.
    fn apply(&self, app: Router<Arc<AppContext>>) -> RResult<Router<Arc<AppContext>>> {
        let limit = Duration::from_millis(self.timeout);
        Ok(app.layer(TimeoutLayer::new(limit)))
    }
}

View File

@ -0,0 +1,5 @@
pub mod config;
pub mod controller;
pub mod middleware;
pub use config::WebServerConfig;

View File

@ -1,6 +1,5 @@
#![allow(unused_imports)] #![allow(unused_imports)]
use insta::{assert_debug_snapshot, with_settings}; use insta::{assert_debug_snapshot, with_settings};
use recorder::app::App1;
use serial_test::serial; use serial_test::serial;
macro_rules! configure_insta { macro_rules! configure_insta {

View File

@ -1,4 +1,4 @@
{ {
"routesDirectory": "./src/controllers", "routesDirectory": "./src/web/controller",
"generatedRouteTree": "./src/routeTree.gen.ts" "generatedRouteTree": "./src/routeTree.gen.ts"
} }

116
bacon.toml Normal file
View File

@ -0,0 +1,116 @@
# This is a configuration file for the bacon tool
#
# Complete help on configuration: https://dystroy.org/bacon/config/
#
# You may check the current default at
# https://github.com/Canop/bacon/blob/main/defaults/default-bacon.toml
default_job = "check"
env.CARGO_TERM_COLOR = "always"
[jobs.recorder]
command = ["cargo", "recorder"]
watch = ["apps/recorder"]
need_stdout = true
[jobs.check]
command = ["cargo", "check"]
need_stdout = false
[jobs.check-all]
command = ["cargo", "check", "--all-targets"]
need_stdout = false
# Run clippy on the default target
[jobs.clippy]
command = ["cargo", "clippy"]
need_stdout = false
# Run clippy on all targets
# To disable some lints, you may change the job this way:
# [jobs.clippy-all]
# command = [
# "cargo", "clippy",
# "--all-targets",
# "--",
# "-A", "clippy::bool_to_int_with_if",
# "-A", "clippy::collapsible_if",
# "-A", "clippy::derive_partial_eq_without_eq",
# ]
# need_stdout = false
[jobs.clippy-all]
command = ["cargo", "clippy", "--all-targets"]
need_stdout = false
# This job lets you run
# - all tests: bacon test
# - a specific test: bacon test -- config::test_default_files
# - the tests of a package: bacon test -- -- -p config
[jobs.test]
command = ["cargo", "test"]
need_stdout = true
[jobs.nextest]
command = [
"cargo", "nextest", "run",
"--hide-progress-bar", "--failure-output", "final"
]
need_stdout = true
analyzer = "nextest"
[jobs.doc]
command = ["cargo", "doc", "--no-deps"]
need_stdout = false
# If the doc compiles, then it opens in your browser and bacon switches
# to the previous job
[jobs.doc-open]
command = ["cargo", "doc", "--no-deps", "--open"]
need_stdout = false
on_success = "back" # so that we don't open the browser at each change
# You can run your application and have the result displayed in bacon,
# if it makes sense for this crate.
[jobs.run]
command = [
"cargo", "run",
# put launch parameters for your program behind a `--` separator
]
need_stdout = true
allow_warnings = true
background = true
# Run your long-running application (eg server) and have the result displayed in bacon.
# For programs that never stop (eg a server), `background` is set to false
# to have the cargo run output immediately displayed instead of waiting for
# program's end.
# 'on_change_strategy' is set to `kill_then_restart` to have your program restart
# on every change (an alternative would be to use the 'F5' key manually in bacon).
# If you often use this job, it makes sense to override the 'r' key by adding
# a binding `r = job:run-long` at the end of this file.
[jobs.run-long]
command = [
"cargo", "run",
# put launch parameters for your program behind a `--` separator
]
need_stdout = true
allow_warnings = true
background = false
on_change_strategy = "kill_then_restart"
# This parameterized job runs the example of your choice, as soon
# as the code compiles.
# Call it as
# bacon ex -- my-example
[jobs.ex]
command = ["cargo", "run", "--example"]
need_stdout = true
allow_warnings = true
# You may define here keybindings that would be specific to
# a project, for example a shortcut to launch a specific job.
# Shortcuts to internal functions (scrolling, toggling, etc.)
# should go in your personal global prefs.toml file instead.
[keybindings]
# alt-m = "job:my-job"
c = "job:clippy-all" # comment this to have 'c' run clippy on only the default target

View File

@ -2,7 +2,6 @@ set windows-shell := ["pwsh.exe", "-c"]
set dotenv-load set dotenv-load
prepare-dev-recorder: prepare-dev-recorder:
cargo install loco-cli
cargo install sea-orm-cli cargo install sea-orm-cli
cargo install cargo-watch cargo install cargo-watch
@ -13,13 +12,10 @@ dev-proxy:
pnpm run --filter=proxy dev pnpm run --filter=proxy dev
dev-recorder: dev-recorder:
cargo watch -w apps/recorder -i '**/*.{js,css,scss,tsx,ts,jsx,html}' -x 'recorder start' bacon recorder
dev-playground: dev-playground:
pnpm run --filter=recorder dev pnpm run --filter=recorder dev
down-recorder:
cargo run -p recorder --bin recorder_cli -- db down 999 --environment development
play-recorder: play-recorder:
cargo recorder-playground cargo recorder-playground