feat: add permission control

This commit is contained in: master
2025-02-22 20:26:14 +08:00
parent ae40a3a7f8
commit c2f74dc369
33 changed files with 707 additions and 226 deletions

Cargo.lock generated
View File

@@ -225,8 +225,9 @@ dependencies = [
 [[package]]
 name = "async-graphql"
-version = "7.0.13"
-source = "git+https://github.com/aumetra/async-graphql.git?rev=690ece7#690ece7cd408e28bfaf0c434fdd4c46ef1a78ef2"
+version = "7.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfff2b17d272a5e3e201feda444e2c24b011fa722951268d1bd8b9b5bc6dc449"
 dependencies = [
  "async-graphql-derive",
  "async-graphql-parser",
@@ -261,8 +262,9 @@ dependencies = [
 [[package]]
 name = "async-graphql-axum"
-version = "7.0.13"
-source = "git+https://github.com/aumetra/async-graphql.git?rev=690ece7#690ece7cd408e28bfaf0c434fdd4c46ef1a78ef2"
+version = "7.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6bf2882c816094fef6e39d381b8e9b710e5943e7bdef5198496441d5083164fa"
 dependencies = [
  "async-graphql",
  "axum",
@@ -277,8 +279,9 @@ dependencies = [
 [[package]]
 name = "async-graphql-derive"
-version = "7.0.13"
-source = "git+https://github.com/aumetra/async-graphql.git?rev=690ece7#690ece7cd408e28bfaf0c434fdd4c46ef1a78ef2"
+version = "7.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d8e5d0c6697def2f79ccbd972fb106b633173a6066e430b480e1ff9376a7561a"
 dependencies = [
  "Inflector",
  "async-graphql-parser",
@@ -293,8 +296,9 @@ dependencies = [
 [[package]]
 name = "async-graphql-parser"
-version = "7.0.13"
-source = "git+https://github.com/aumetra/async-graphql.git?rev=690ece7#690ece7cd408e28bfaf0c434fdd4c46ef1a78ef2"
+version = "7.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8531ee6d292c26df31c18c565ff22371e7bdfffe7f5e62b69537db0b8fd554dc"
 dependencies = [
  "async-graphql-value",
  "pest",
@@ -304,8 +308,9 @@ dependencies = [
 [[package]]
 name = "async-graphql-value"
-version = "7.0.13"
-source = "git+https://github.com/aumetra/async-graphql.git?rev=690ece7#690ece7cd408e28bfaf0c434fdd4c46ef1a78ef2"
+version = "7.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "741110dda927420a28fbc1c310543d3416f789a6ba96859c2c265843a0a96887"
 dependencies = [
  "bytes",
  "indexmap 2.7.1",
@@ -6801,8 +6806,9 @@ dependencies = [
 [[package]]
 name = "testcontainers"
-version = "0.23.1"
-source = "git+https://github.com/testcontainers/testcontainers-rs.git?rev=af21727#af2172714bbb79c6ce648b699135922f85cafc0c"
+version = "0.23.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59a4f01f39bb10fc2a5ab23eb0d888b1e2bb168c157f61a1b98e6c501c639c74"
 dependencies = [
  "async-trait",
  "bollard",

View File

@@ -3,12 +3,9 @@ members = ["apps/recorder"]
 resolver = "2"

 [patch.crates-io]
-testcontainers = { git = "https://github.com/testcontainers/testcontainers-rs.git", rev = "af21727" }
 # loco-rs = { git = "https://github.com/lonelyhentxi/loco.git", rev = "beb890e" }
 # loco-rs = { git = "https://github.com/loco-rs/loco.git" }
 # loco-rs = { path = "./patches/loco" }
-async-graphql = { git = "https://github.com/aumetra/async-graphql.git", rev = "690ece7" }
-async-graphql-axum = { git = "https://github.com/aumetra/async-graphql.git", rev = "690ece7" }
 jwt-authorizer = { git = "https://github.com/blablacio/jwt-authorizer.git", rev = "e956774" }

 # [patch."https://github.com/lonelyhentxi/qbit.git"]

View File

@@ -29,7 +29,7 @@ tokio = { version = "1.42", features = ["macros", "fs", "rt-multi-thread"] }
 async-trait = "0.1.83"
 tracing = "0.1"
 chrono = "0.4"
-sea-orm = { version = "1", features = [
+sea-orm = { version = "1.1", features = [
   "sqlx-sqlite",
   "sqlx-postgres",
   "runtime-tokio-rustls",
@@ -41,7 +41,7 @@ figment = { version = "0.10", features = ["toml", "json", "env", "yaml"] }
 axum = "0.8"
 uuid = { version = "1.6.0", features = ["v4"] }
 tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
-sea-orm-migration = { version = "1", features = ["runtime-tokio-rustls"] }
+sea-orm-migration = { version = "1.1", features = ["runtime-tokio-rustls"] }
 reqwest = { version = "0.12", features = [
   "charset",
   "http2",
@@ -76,7 +76,7 @@ qbit-rs = { git = "https://github.com/lonelyhentxi/qbit.git", rev = "72d53138ebe
   "default",
   "builder",
 ] }
-testcontainers = { version = "0.23.1", features = [
+testcontainers = { version = "0.23.3", features = [
   "default",
   "properties-config",
   "watchdog",
@@ -88,10 +88,10 @@ color-eyre = "0.6"
 log = "0.4.22"
 anyhow = "1.0.95"
 bollard = { version = "0.18", optional = true }
-async-graphql = { version = "7.0.13", features = [] }
-async-graphql-axum = "7.0.13"
+async-graphql = { version = "7.0.15", features = [] }
+async-graphql-axum = "7.0.15"
 fastrand = "2.3.0"
-seaography = "1.1.2"
+seaography = "1.1"
 quirks_path = "0.1.1"
 base64 = "0.22.1"
 tower = "0.5.2"
@@ -99,7 +99,7 @@ axum-extra = "0.10.0"
 tower-http = "0.6.2"
 serde_yaml = "0.9.34"
 tera = "1.20.0"
-openidconnect = "4.0.0-rc.1"
+openidconnect = "4"
 http-cache-reqwest = { version = "0.15", features = [
   "manager-cacache",
   "manager-moka",

View File

@@ -1,6 +1,7 @@
 use async_trait::async_trait;
-use axum::http::{request::Parts, HeaderValue};
+use axum::http::{HeaderValue, request::Parts};
 use base64::{self, Engine};
+use loco_rs::app::AppContext;
 use reqwest::header::AUTHORIZATION;

 use super::{
@@ -59,7 +60,11 @@ pub struct BasicAuthService {

 #[async_trait]
 impl AuthService for BasicAuthService {
-    async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
+    async fn extract_user_info(
+        &self,
+        ctx: &AppContext,
+        request: &mut Parts,
+    ) -> Result<AuthUserInfo, AuthError> {
         if let Ok(AuthBasic {
             user: found_user,
             password: found_password,
@@ -68,8 +73,11 @@ impl AuthService for BasicAuthService {
             if self.config.user == found_user
                 && self.config.password == found_password.unwrap_or_default()
             {
+                let subscriber_auth = crate::models::auth::Model::find_by_pid(ctx, SEED_SUBSCRIBER)
+                    .await
+                    .map_err(AuthError::FindAuthRecordError)?;
                 return Ok(AuthUserInfo {
-                    user_pid: SEED_SUBSCRIBER.to_string(),
+                    subscriber_auth,
                     auth_type: AuthType::Basic,
                 });
             }
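With `extract_user_info` now taking the `AppContext`, basic auth resolves the seed subscriber's auth record instead of returning a bare pid, and downstream code reads the numeric subscriber id from the attached record. A minimal sketch of a consumer, assuming the `crate::auth` module path; the handler itself is illustrative and not part of this commit:

    use axum::Json;
    use crate::auth::AuthUserInfo; // module path assumed

    // Hypothetical handler: AuthUserInfo implements FromRequestParts<AppContext>
    // (see the auth service diff below), so it can be taken as an extractor and
    // the subscriber id read straight from the resolved auth record.
    async fn whoami(user: AuthUserInfo) -> Json<i32> {
        Json(user.subscriber_auth.subscriber_id)
    }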

View File

@@ -1,11 +1,14 @@
+use std::fmt;
+
 use axum::{
+    Json,
     http::StatusCode,
     response::{IntoResponse, Response},
-    Json,
 };
+use loco_rs::model::ModelError;
 use openidconnect::{
-    core::CoreErrorResponseType, ConfigurationError, RequestTokenError, SignatureVerificationError,
-    SigningError, StandardErrorResponse,
+    ConfigurationError, RequestTokenError, SignatureVerificationError, SigningError,
+    StandardErrorResponse, core::CoreErrorResponseType,
 };
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
@@ -19,6 +22,8 @@ pub enum AuthError {
         supported: Vec<AuthType>,
         current: AuthType,
     },
+    #[error("Failed to find auth record")]
+    FindAuthRecordError(ModelError),
     #[error("Invalid credentials")]
     BasicInvalidCredentials,
     #[error(transparent)]
@@ -69,6 +74,15 @@ pub enum AuthError {
     OidcAudMissingError(String),
     #[error("Subject missing")]
     OidcSubMissingError,
+    #[error(fmt = display_graphql_permission_error)]
+    GraphQLPermissionError(async_graphql::Error),
+}
+
+fn display_graphql_permission_error(
+    error: &async_graphql::Error,
+    formatter: &mut fmt::Formatter<'_>,
+) -> fmt::Result {
+    write!(formatter, "GraphQL permission denied: {}", error.message)
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
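The new `GraphQLPermissionError` variant delegates its `Display` impl to a free function via thiserror's `fmt` attribute. A small sketch of the expected rendering; the test function is illustrative:

    #[test]
    fn graphql_permission_error_display() {
        // Relies on thiserror's `#[error(fmt = ...)]` delegation added above.
        let err = AuthError::GraphQLPermissionError(async_graphql::Error::new(
            "subscriber permission denied",
        ));
        assert_eq!(
            err.to_string(),
            "GraphQL permission denied: subscriber permission denied"
        );
    }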

View File

@@ -19,7 +19,7 @@ pub async fn api_auth_middleware(
     let (mut parts, body) = request.into_parts();

-    let mut response = match auth_service.extract_user_info(&mut parts).await {
+    let mut response = match auth_service.extract_user_info(&ctx, &mut parts).await {
         Ok(auth_user_info) => {
             let mut request = Request::from_parts(parts, body);
             request.extensions_mut().insert(auth_user_info);

View File

@@ -4,14 +4,15 @@ use std::{
 };

 use async_trait::async_trait;
-use axum::http::{request::Parts, HeaderValue};
+use axum::http::{HeaderValue, request::Parts};
 use itertools::Itertools;
-use jwt_authorizer::{authorizer::Authorizer, NumericDate, OneOrArray};
+use jwt_authorizer::{NumericDate, OneOrArray, authorizer::Authorizer};
+use loco_rs::{app::AppContext, model::ModelError};
 use moka::future::Cache;
 use openidconnect::{
-    core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
     AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
     OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
+    core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
 };
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -258,7 +259,11 @@ impl OidcAuthService {

 #[async_trait]
 impl AuthService for OidcAuthService {
-    async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
+    async fn extract_user_info(
+        &self,
+        ctx: &AppContext,
+        request: &mut Parts,
+    ) -> Result<AuthUserInfo, AuthError> {
         let config = &self.config;
         let token = self.api_authorizer.extract_token(&request.headers).ok_or(
             AuthError::OidcJwtAuthError(jwt_authorizer::AuthError::MissingToken()),
@@ -266,9 +271,11 @@ impl AuthService for OidcAuthService {

         let token_data = self.api_authorizer.check_auth(&token).await?;
         let claims = token_data.claims;
-        if claims.sub.as_deref().is_none_or(|s| s.trim().is_empty()) {
+        let sub = if let Some(sub) = claims.sub.as_deref() {
+            sub
+        } else {
             return Err(AuthError::OidcSubMissingError);
-        }
+        };
         if !claims.contains_audience(&config.audience) {
             return Err(AuthError::OidcAudMissingError(config.audience.clone()));
         }
@@ -298,12 +305,16 @@ impl AuthService for OidcAuthService {
                 }
             }
         }
+
+        let subscriber_auth = match crate::models::auth::Model::find_by_pid(ctx, sub).await {
+            Err(ModelError::EntityNotFound) => {
+                crate::models::auth::Model::create_from_oidc(ctx, sub.to_string()).await
+            }
+            r => r,
+        }
+        .map_err(AuthError::FindAuthRecordError)?;
         Ok(AuthUserInfo {
-            user_pid: claims
-                .sub
-                .as_deref()
-                .map(|s| s.trim().to_string())
-                .unwrap_or_else(|| unreachable!("sub should be present and validated")),
+            subscriber_auth,
             auth_type: AuthType::Oidc,
         })
     }

View File

@@ -13,24 +13,24 @@ use once_cell::sync::OnceCell;
 use reqwest::header::HeaderValue;

 use super::{
+    AppAuthConfig,
     basic::BasicAuthService,
     errors::AuthError,
     oidc::{OidcAuthClaims, OidcAuthService},
-    AppAuthConfig,
 };
 use crate::{
     app::AppContextExt as _,
     config::AppConfigExt,
     fetch::{
-        client::{HttpClientCacheBackendConfig, HttpClientCachePresetConfig},
         HttpClient, HttpClientConfig,
+        client::{HttpClientCacheBackendConfig, HttpClientCachePresetConfig},
     },
     models::auth::AuthType,
 };

 #[derive(Clone, Debug)]
 pub struct AuthUserInfo {
-    pub user_pid: String,
+    pub subscriber_auth: crate::models::auth::Model,
     pub auth_type: AuthType,
 }
@@ -44,7 +44,7 @@ impl FromRequestParts<AppContext> for AuthUserInfo {
         let auth_service = state.get_auth_service();

         auth_service
-            .extract_user_info(parts)
+            .extract_user_info(state, parts)
             .await
             .map_err(|err| err.into_response())
     }
@@ -52,7 +52,11 @@ impl FromRequestParts<AppContext> for AuthUserInfo {

 #[async_trait]
 pub trait AuthService {
-    async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError>;
+    async fn extract_user_info(
+        &self,
+        ctx: &AppContext,
+        request: &mut Parts,
+    ) -> Result<AuthUserInfo, AuthError>;
     fn www_authenticate_header_value(&self) -> Option<HeaderValue>;
     fn auth_type(&self) -> AuthType;
 }
@@ -79,21 +83,23 @@ impl AppAuthService {
             .iss(&[&config.issuer])
             .aud(&[&config.audience]);

-        let jwt_auth = JwtAuthorizer::<OidcAuthClaims>::from_oidc(&config.issuer)
+        let oidc_provider_client = HttpClient::from_config(HttpClientConfig {
+            exponential_backoff_max_retries: Some(3),
+            cache_backend: Some(HttpClientCacheBackendConfig::Moka { cache_size: 1 }),
+            cache_preset: Some(HttpClientCachePresetConfig::RFC7234),
+            ..Default::default()
+        })
+        .map_err(AuthError::OidcProviderHttpClientError)?;
+
+        let api_authorizer = JwtAuthorizer::<OidcAuthClaims>::from_oidc(&config.issuer)
             .validation(validation)
             .build()
             .await?;

         AppAuthService::Oidc(OidcAuthService {
             config,
-            api_authorizer: jwt_auth,
-            oidc_provider_client: HttpClient::from_config(HttpClientConfig {
-                exponential_backoff_max_retries: Some(3),
-                cache_backend: Some(HttpClientCacheBackendConfig::Moka { cache_size: 1 }),
-                cache_preset: Some(HttpClientCachePresetConfig::RFC7234),
-                ..Default::default()
-            })
-            .map_err(AuthError::OidcProviderHttpClientError)?,
+            api_authorizer,
+            oidc_provider_client,
             oidc_request_cache: Cache::builder()
                 .time_to_live(Duration::from_mins(5))
                 .name("oidc_request_cache")
@@ -107,10 +113,14 @@ impl AppAuthService {

 #[async_trait]
 impl AuthService for AppAuthService {
-    async fn extract_user_info(&self, request: &mut Parts) -> Result<AuthUserInfo, AuthError> {
+    async fn extract_user_info(
+        &self,
+        ctx: &AppContext,
+        request: &mut Parts,
+    ) -> Result<AuthUserInfo, AuthError> {
         match self {
-            AppAuthService::Basic(service) => service.extract_user_info(request).await,
-            AppAuthService::Oidc(service) => service.extract_user_info(request).await,
+            AppAuthService::Basic(service) => service.extract_user_info(ctx, request).await,
+            AppAuthService::Oidc(service) => service.extract_user_info(ctx, request).await,
         }
     }

View File

@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use loco_rs::app::{AppContext, Initializer};
 use once_cell::sync::OnceCell;
-use opendal::{layers::LoggingLayer, services::Fs, Buffer, Operator};
+use opendal::{Buffer, Operator, layers::LoggingLayer, services::Fs};
 use quirks_path::{Path, PathBuf};
 use serde::{Deserialize, Serialize};
 use url::Url;
@@ -81,7 +81,7 @@ impl AppDalClient {
     pub async fn store_object(
         &self,
         content_category: DalContentCategory,
-        subscriber_pid: &str,
+        subscriber_id: i32,
         bucket: Option<&str>,
         filename: &str,
         data: Bytes,
@@ -89,7 +89,7 @@ impl AppDalClient {
         match content_category {
             DalContentCategory::Image => {
                 let fullname = [
-                    subscriber_pid,
+                    &subscriber_id.to_string(),
                     content_category.as_ref(),
                     bucket.unwrap_or_default(),
                     filename,
@@ -119,14 +119,14 @@ impl AppDalClient {
     pub async fn exists_object(
         &self,
         content_category: DalContentCategory,
-        subscriber_pid: &str,
+        subscriber_id: i32,
         bucket: Option<&str>,
         filename: &str,
     ) -> color_eyre::eyre::Result<Option<DalStoredUrl>> {
         match content_category {
             DalContentCategory::Image => {
                 let fullname = [
-                    subscriber_pid,
+                    &subscriber_id.to_string(),
                     content_category.as_ref(),
                     bucket.unwrap_or_default(),
                     filename,
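Stored objects are now namespaced by the numeric subscriber id rather than the removed pid. A hedged usage sketch, assuming the `AppContextExt::get_dal_client` accessor used elsewhere in this commit; the bucket and filename values are illustrative:

    use bytes::Bytes;
    use loco_rs::app::AppContext;
    use crate::{app::AppContextExt, dal::DalContentCategory};

    // Hypothetical helper: store an image under the subscriber's numeric namespace,
    // producing a key roughly of the form "<subscriber_id>/image/<bucket>/<filename>".
    async fn store_poster_for(
        ctx: &AppContext,
        subscriber_id: i32,
        data: Bytes,
    ) -> color_eyre::eyre::Result<()> {
        let dal_client = ctx.get_dal_client();
        dal_client
            .store_object(
                DalContentCategory::Image,
                subscriber_id,        // i32 now, previously a &str pid
                Some("mikan_poster"), // bucket name is illustrative
                "poster.webp",        // filename is illustrative
                data,
            )
            .await?;
        Ok(())
    }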

View File

@@ -12,14 +12,13 @@ use scraper::Html;
 use url::Url;

 use super::{
-    parse_mikan_bangumi_id_from_rss_link, AppMikanClient, MikanBangumiRssLink, MIKAN_BUCKET_KEY,
+    AppMikanClient, MIKAN_BUCKET_KEY, MikanBangumiRssLink, parse_mikan_bangumi_id_from_rss_link,
 };
 use crate::{
     app::AppContextExt,
     dal::DalContentCategory,
     extract::html::parse_style_attr,
     fetch::{html::fetch_html, image::fetch_image},
-    models::subscribers,
 };

 #[derive(Clone, Debug, PartialEq)]
@@ -110,11 +109,10 @@ pub async fn parse_mikan_bangumi_poster_from_origin_poster_src_with_cache(
 ) -> color_eyre::eyre::Result<MikanBangumiPosterMeta> {
     let dal_client = ctx.get_dal_client();
     let mikan_client = ctx.get_mikan_client();
-    let subscriber_pid = &subscribers::Model::find_pid_by_id_with_cache(ctx, subscriber_id).await?;
     if let Some(poster_src) = dal_client
         .exists_object(
             DalContentCategory::Image,
-            subscriber_pid,
+            subscriber_id,
             Some(MIKAN_BUCKET_KEY),
             &origin_poster_src.path().replace("/images/Bangumi/", ""),
         )
@@ -132,7 +130,7 @@ pub async fn parse_mikan_bangumi_poster_from_origin_poster_src_with_cache(
     let poster_str = dal_client
         .store_object(
             DalContentCategory::Image,
-            subscriber_pid,
+            subscriber_id,
             Some(MIKAN_BUCKET_KEY),
             &origin_poster_src.path().replace("/images/Bangumi/", ""),
             poster_data.clone(),

View File

@@ -0,0 +1,27 @@
use std::sync::Arc;
use async_graphql::{
ServerResult, Value,
extensions::{Extension, ExtensionContext, ExtensionFactory, NextResolve, ResolveInfo},
};
pub struct GraphqlAuthExtension;
#[async_trait::async_trait]
impl Extension for GraphqlAuthExtension {
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextResolve<'_>,
) -> ServerResult<Option<Value>> {
dbg!(info.field);
next.run(ctx, info).await
}
}
impl ExtensionFactory for GraphqlAuthExtension {
fn create(&self) -> Arc<dyn Extension> {
Arc::new(GraphqlAuthExtension)
}
}

View File

@@ -0,0 +1,199 @@
use std::sync::Arc;
use async_graphql::dynamic::{ResolverContext, ValueAccessor};
use sea_orm::EntityTrait;
use seaography::{BuilderContext, FnGuard, GuardAction};
use super::util::get_entity_key;
use crate::{
auth::{AuthError, AuthUserInfo},
graphql::util::get_column_key,
};
fn guard_data_object_accessor_with_subscriber_id(
value: ValueAccessor<'_>,
column_name: &str,
subscriber_id: i32,
) -> async_graphql::Result<()> {
let obj = value.object()?;
let subscriber_id_value = obj.try_get(column_name)?;
let id = subscriber_id_value.i64()?;
if id == subscriber_id as i64 {
Ok(())
} else {
Err(async_graphql::Error::new("subscriber permission denied"))
}
}
fn guard_data_object_accessor_with_optional_subscriber_id(
value: ValueAccessor<'_>,
column_name: &str,
subscriber_id: i32,
) -> async_graphql::Result<()> {
if value.is_null() {
return Ok(());
}
let obj = value.object()?;
if let Some(subscriber_id_value) = obj.get(column_name) {
let id = subscriber_id_value.i64()?;
if id == subscriber_id as i64 {
Ok(())
} else {
Err(async_graphql::Error::new("subscriber permission denied"))
}
} else {
Ok(())
}
}
fn guard_filter_object_accessor_with_subscriber_id(
value: ValueAccessor<'_>,
column_name: &str,
subscriber_id: i32,
) -> async_graphql::Result<()> {
let obj = value.object()?;
let subscriber_id_filter_input_value = obj.try_get(column_name)?;
let subscriber_id_filter_input_obj = subscriber_id_filter_input_value.object()?;
let subscriber_id_value = subscriber_id_filter_input_obj.try_get("eq")?;
let id = subscriber_id_value.i64()?;
if id == subscriber_id as i64 {
Ok(())
} else {
Err(async_graphql::Error::new("subscriber permission denied"))
}
}
pub fn guard_entity_with_subscriber_id<T>(context: &BuilderContext, column: &T::Column) -> FnGuard
where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
let entity_key = get_entity_key::<T>(context);
let entity_name = context.entity_query_field.type_name.as_ref()(&entity_key);
let column_key = get_column_key::<T>(context, column);
let column_name = Arc::new(context.entity_object.column_name.as_ref()(
&entity_key,
&column_key,
));
let entity_create_one_mutation_field_name = Arc::new(format!(
"{}{}",
entity_name, context.entity_create_one_mutation.mutation_suffix
));
let entity_create_one_mutation_data_field_name =
Arc::new(context.entity_create_one_mutation.data_field.clone());
let entity_create_batch_mutation_field_name = Arc::new(format!(
"{}{}",
entity_name,
context.entity_create_batch_mutation.mutation_suffix.clone()
));
let entity_create_batch_mutation_data_field_name =
Arc::new(context.entity_create_batch_mutation.data_field.clone());
let entity_delete_mutation_field_name = Arc::new(format!(
"{}{}",
entity_name,
context.entity_delete_mutation.mutation_suffix.clone()
));
let entity_delete_filter_field_name =
Arc::new(context.entity_delete_mutation.filter_field.clone());
let entity_update_mutation_field_name = Arc::new(format!(
"{}{}",
entity_name, context.entity_update_mutation.mutation_suffix
));
let entity_update_mutation_filter_field_name =
Arc::new(context.entity_update_mutation.filter_field.clone());
let entity_update_mutation_data_field_name =
Arc::new(context.entity_update_mutation.data_field.clone());
let entity_query_field_name = Arc::new(entity_name);
let entity_query_filter_field_name = Arc::new(context.entity_query_field.filters.clone());
Box::new(move |context: &ResolverContext| -> GuardAction {
match context.ctx.data::<AuthUserInfo>() {
Ok(user_info) => {
let subscriber_id = user_info.subscriber_auth.subscriber_id;
let validation_result = match context.field().name() {
field if field == entity_create_one_mutation_field_name.as_str() => context
.args
.try_get(&entity_create_one_mutation_data_field_name)
.and_then(|data_value| {
guard_data_object_accessor_with_subscriber_id(
data_value,
&column_name,
subscriber_id,
)
}),
field if field == entity_create_batch_mutation_field_name.as_str() => context
.args
.try_get(&entity_create_batch_mutation_data_field_name)
.and_then(|data_value| {
data_value.list().and_then(|data_list| {
data_list.iter().try_for_each(|data_item_value| {
guard_data_object_accessor_with_subscriber_id(
data_item_value,
&column_name,
subscriber_id,
)
})
})
}),
field if field == entity_delete_mutation_field_name.as_str() => context
.args
.try_get(&entity_delete_filter_field_name)
.and_then(|filter_value| {
guard_filter_object_accessor_with_subscriber_id(
filter_value,
&column_name,
subscriber_id,
)
}),
field if field == entity_update_mutation_field_name.as_str() => context
.args
.try_get(&entity_update_mutation_filter_field_name)
.and_then(|filter_value| {
guard_filter_object_accessor_with_subscriber_id(
filter_value,
&column_name,
subscriber_id,
)
})
.and_then(|_| {
match context.args.get(&entity_update_mutation_data_field_name) {
Some(data_value) => {
guard_data_object_accessor_with_optional_subscriber_id(
data_value,
&column_name,
subscriber_id,
)
}
None => Ok(()),
}
}),
field if field == entity_query_field_name.as_str() => context
.args
.try_get(&entity_query_filter_field_name)
.and_then(|filter_value| {
guard_filter_object_accessor_with_subscriber_id(
filter_value,
&column_name,
subscriber_id,
)
}),
field => Err(async_graphql::Error::new(format!(
"unsupport graphql field {}",
field
))),
};
match validation_result.map_err(AuthError::GraphQLPermissionError) {
Ok(_) => GuardAction::Allow,
Err(err) => GuardAction::Block(Some(err.to_string())),
}
}
Err(err) => GuardAction::Block(Some(err.message)),
}
})
}
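The guard inspects the arguments of each generated field: the `data` object(s) on create and update mutations, and the `eq` value of the subscriber-id filter on queries, deletes, and updates. A hedged sketch of how this is expected to surface at execution time; the query field, `filters` argument, and camel-cased column names follow seaography defaults and are assumptions here, as is the wiring of the helper:

    use async_graphql::dynamic::Schema;
    use crate::auth::AuthUserInfo; // module path assumed

    // Hypothetical helper: execute a scoped query against the schema built by
    // graphql::schema_root::schema(...), with the caller's AuthUserInfo attached
    // as request data so the guard can read subscriber_auth.subscriber_id.
    async fn query_own_subscriptions(
        schema: &Schema,
        auth_user_info: AuthUserInfo,
    ) -> async_graphql::Response {
        let request = async_graphql::Request::new(
            r#"{ subscriptions(filters: { subscriberId: { eq: 1 } }) { nodes { id } } }"#,
        )
        .data(auth_user_info);
        // If the eq value does not match the caller's subscriber id (or the filter
        // is missing), the guard blocks the resolver and the error reads
        // "GraphQL permission denied: subscriber permission denied".
        schema.execute(request).await
    }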

View File

@@ -1,5 +1,8 @@
 pub mod config;
-pub mod query_root;
+pub mod extention;
+pub mod guard;
+pub mod schema_root;
 pub mod service;
+pub mod util;

-pub use query_root::schema;
+pub use schema_root::schema;

View File

@@ -1,56 +0,0 @@
use async_graphql::dynamic::*;
use sea_orm::DatabaseConnection;
use seaography::{Builder, BuilderContext};
lazy_static::lazy_static! { static ref CONTEXT : BuilderContext = {
BuilderContext {
..Default::default()
}
}; }
pub fn schema(
database: DatabaseConnection,
depth: Option<usize>,
complexity: Option<usize>,
) -> Result<Schema, SchemaError> {
use crate::models::*;
let mut builder = Builder::new(&CONTEXT, database.clone());
seaography::register_entities!(
builder,
[
bangumi,
downloaders,
downloads,
episodes,
subscribers,
subscription_bangumi,
subscription_episode,
subscriptions
]
);
{
builder.register_enumeration::<downloads::DownloadStatus>();
builder.register_enumeration::<subscriptions::SubscriptionCategory>();
builder.register_enumeration::<downloaders::DownloaderCategory>();
builder.register_enumeration::<downloads::DownloadMime>();
}
let schema = builder.schema_builder();
let schema = if let Some(depth) = depth {
schema.limit_depth(depth)
} else {
schema
};
let schema = if let Some(complexity) = complexity {
schema.limit_complexity(complexity)
} else {
schema
};
schema
.data(database)
.finish()
.inspect_err(|e| tracing::error!(e = ?e))
}

View File

@@ -0,0 +1,146 @@
use async_graphql::dynamic::*;
use once_cell::sync::OnceCell;
use sea_orm::{DatabaseConnection, EntityTrait, Iterable};
use seaography::{Builder, BuilderContext, FilterType, FnGuard};
use super::util::{get_entity_column_key, get_entity_key};
use crate::graphql::guard::guard_entity_with_subscriber_id;
static CONTEXT: OnceCell<BuilderContext> = OnceCell::new();
fn restrict_filter_input_for_entity<T>(
context: &mut BuilderContext,
column: &T::Column,
filter_type: Option<FilterType>,
) where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
let key = get_entity_column_key::<T>(context, column);
context.filter_types.overwrites.insert(key, filter_type);
}
fn restrict_subscriber_for_entity<T>(
context: &mut BuilderContext,
column: &T::Column,
entity_guard: impl FnOnce(&BuilderContext, &T::Column) -> FnGuard,
) where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
let entity_key = get_entity_key::<T>(context);
context
.guards
.entity_guards
.insert(entity_key, entity_guard(context, column));
}
pub fn schema(
database: DatabaseConnection,
depth: Option<usize>,
complexity: Option<usize>,
) -> Result<Schema, SchemaError> {
use crate::models::*;
let context = CONTEXT.get_or_init(|| {
let mut context = BuilderContext::default();
restrict_subscriber_for_entity::<bangumi::Entity>(
&mut context,
&bangumi::Column::SubscriberId,
guard_entity_with_subscriber_id::<bangumi::Entity>,
);
restrict_subscriber_for_entity::<downloaders::Entity>(
&mut context,
&downloaders::Column::SubscriberId,
guard_entity_with_subscriber_id::<downloaders::Entity>,
);
restrict_subscriber_for_entity::<downloads::Entity>(
&mut context,
&downloads::Column::SubscriberId,
guard_entity_with_subscriber_id::<downloads::Entity>,
);
restrict_subscriber_for_entity::<episodes::Entity>(
&mut context,
&episodes::Column::SubscriberId,
guard_entity_with_subscriber_id::<episodes::Entity>,
);
restrict_subscriber_for_entity::<subscriptions::Entity>(
&mut context,
&subscriptions::Column::SubscriberId,
guard_entity_with_subscriber_id::<subscriptions::Entity>,
);
restrict_subscriber_for_entity::<subscribers::Entity>(
&mut context,
&subscribers::Column::Id,
guard_entity_with_subscriber_id::<subscribers::Entity>,
);
restrict_subscriber_for_entity::<subscription_bangumi::Entity>(
&mut context,
&subscription_bangumi::Column::SubscriberId,
guard_entity_with_subscriber_id::<subscription_bangumi::Entity>,
);
restrict_subscriber_for_entity::<subscription_episode::Entity>(
&mut context,
&subscription_episode::Column::SubscriberId,
guard_entity_with_subscriber_id::<subscription_episode::Entity>,
);
for column in subscribers::Column::iter() {
if !matches!(column, subscribers::Column::Id) {
restrict_filter_input_for_entity::<subscribers::Entity>(
&mut context,
&column,
None,
);
}
}
context
});
let mut builder = Builder::new(context, database.clone());
{
builder.register_entity::<subscribers::Entity>(
<subscribers::RelatedEntity as sea_orm::Iterable>::iter()
.map(|rel| seaography::RelationBuilder::get_relation(&rel, builder.context))
.collect(),
);
builder = builder.register_entity_dataloader_one_to_one(subscribers::Entity, tokio::spawn);
builder = builder.register_entity_dataloader_one_to_many(subscribers::Entity, tokio::spawn);
}
seaography::register_entities!(
builder,
[
bangumi,
downloaders,
downloads,
episodes,
subscription_bangumi,
subscription_episode,
subscriptions
]
);
{
builder.register_enumeration::<downloads::DownloadStatus>();
builder.register_enumeration::<subscriptions::SubscriptionCategory>();
builder.register_enumeration::<downloaders::DownloaderCategory>();
builder.register_enumeration::<downloads::DownloadMime>();
}
let schema = builder.schema_builder();
let schema = if let Some(depth) = depth {
schema.limit_depth(depth)
} else {
schema
};
let schema = if let Some(complexity) = complexity {
schema.limit_complexity(complexity)
} else {
schema
};
schema
.data(database)
// .extension(GraphqlAuthExtension)
.finish()
.inspect_err(|e| tracing::error!(e = ?e))
}
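Scoping a future entity to its subscriber follows the same pattern inside the `CONTEXT` initializer. A hypothetical sketch, where the `my_table` module and its columns are illustrative only:

    // Hypothetical: wire a new entity into the same permission scheme.
    restrict_subscriber_for_entity::<my_table::Entity>(
        &mut context,
        &my_table::Column::SubscriberId,
        guard_entity_with_subscriber_id::<my_table::Entity>,
    );
    // Optionally hide sensitive columns from the generated filter input,
    // as done for the subscribers entity above.
    restrict_filter_input_for_entity::<my_table::Entity>(
        &mut context,
        &my_table::Column::SomePrivateColumn,
        None,
    );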

View File

@@ -4,7 +4,7 @@ use loco_rs::app::{AppContext, Initializer};
 use once_cell::sync::OnceCell;
 use sea_orm::DatabaseConnection;

-use super::{config::AppGraphQLConfig, query_root};
+use super::{config::AppGraphQLConfig, schema_root};
 use crate::config::AppConfigExt;

 static APP_GRAPHQL_SERVICE: OnceCell<AppGraphQLService> = OnceCell::new();
@@ -16,7 +16,7 @@ pub struct AppGraphQLService {

 impl AppGraphQLService {
     pub fn new(config: AppGraphQLConfig, db: DatabaseConnection) -> Result<Self, SchemaError> {
-        let schema = query_root::schema(db, config.depth_limit, config.complexity_limit)?;
+        let schema = schema_root::schema(db, config.depth_limit, config.complexity_limit)?;
         Ok(Self { schema })
     }

View File

@@ -0,0 +1,30 @@
use sea_orm::{EntityName, EntityTrait, IdenStatic};
use seaography::BuilderContext;
pub fn get_entity_key<T>(context: &BuilderContext) -> String
where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
context.entity_object.type_name.as_ref()(<T as EntityName>::table_name(&T::default()))
}
pub fn get_column_key<T>(context: &BuilderContext, column: &T::Column) -> String
where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
let entity_name = get_entity_key::<T>(context);
context.entity_object.column_name.as_ref()(&entity_name, column.as_str())
}
pub fn get_entity_column_key<T>(context: &BuilderContext, column: &T::Column) -> String
where
T: EntityTrait,
<T as EntityTrait>::Model: Sync,
{
let entity_name = get_entity_key::<T>(context);
let column_name = get_column_key::<T>(context, column);
format!("{}.{}", &entity_name, &column_name)
}

View File

@@ -16,7 +16,6 @@ pub enum GeneralIds {
 pub enum Subscribers {
     Table,
     Id,
-    Pid,
     DisplayName,
     DownloaderId,
     BangumiConf,
@@ -58,6 +57,7 @@ pub enum Bangumi {
 pub enum SubscriptionBangumi {
     Table,
     Id,
+    SubscriberId,
     SubscriptionId,
     BangumiId,
 }
@@ -90,6 +90,7 @@ pub enum Episodes {
 pub enum SubscriptionEpisode {
     Table,
     Id,
+    SubscriberId,
     SubscriptionId,
     EpisodeId,
 }
@@ -130,7 +131,6 @@ pub enum Auth {
     Id,
     Pid,
     SubscriberId,
-    AvatarUrl,
     AuthType,
 }

View File

@@ -24,7 +24,6 @@ impl MigrationTrait for Migration {
             .create_table(
                 table_auto(Subscribers::Table)
                     .col(pk_auto(Subscribers::Id))
-                    .col(string_len_uniq(Subscribers::Pid, 64))
                     .col(string(Subscribers::DisplayName))
                     .col(json_binary_null(Subscribers::BangumiConf))
                     .to_owned(),
@@ -42,8 +41,8 @@ impl MigrationTrait for Migration {
             .exec_stmt(
                 Query::insert()
                     .into_table(Subscribers::Table)
-                    .columns([Subscribers::Pid, Subscribers::DisplayName])
-                    .values_panic([SEED_SUBSCRIBER.into(), SEED_SUBSCRIBER.into()])
+                    .columns([Subscribers::DisplayName])
+                    .values_panic([SEED_SUBSCRIBER.into()])
                     .to_owned(),
             )
             .await?;
@@ -159,6 +158,7 @@ impl MigrationTrait for Migration {
             .create_table(
                 table_auto(SubscriptionBangumi::Table)
                     .col(pk_auto(SubscriptionBangumi::Id))
+                    .col(integer(SubscriptionBangumi::SubscriberId))
                     .col(integer(SubscriptionBangumi::SubscriptionId))
                     .col(integer(SubscriptionBangumi::BangumiId))
                     .foreign_key(
@@ -193,6 +193,17 @@ impl MigrationTrait for Migration {
             )
             .await?;

+        manager
+            .create_index(
+                Index::create()
+                    .if_not_exists()
+                    .name("index_subscription_bangumi_subscriber_id")
+                    .table(SubscriptionBangumi::Table)
+                    .col(SubscriptionBangumi::SubscriberId)
+                    .to_owned(),
+            )
+            .await?;
+
         manager
             .create_table(
                 table_auto(Episodes::Table)
@@ -268,6 +279,7 @@ impl MigrationTrait for Migration {
                     .col(pk_auto(SubscriptionEpisode::Id))
                     .col(integer(SubscriptionEpisode::SubscriptionId))
                     .col(integer(SubscriptionEpisode::EpisodeId))
+                    .col(integer(SubscriptionEpisode::SubscriberId))
                     .foreign_key(
                         ForeignKey::create()
                             .name("fk_subscription_episode_subscription_id")
@@ -300,10 +312,31 @@ impl MigrationTrait for Migration {
             )
             .await?;

+        manager
+            .create_index(
+                Index::create()
+                    .if_not_exists()
+                    .name("index_subscription_episode_subscriber_id")
+                    .table(SubscriptionEpisode::Table)
+                    .col(SubscriptionEpisode::SubscriberId)
+                    .to_owned(),
+            )
+            .await?;
+
         Ok(())
     }

     async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
+        manager
+            .drop_index(
+                Index::drop()
+                    .if_exists()
+                    .name("index_subscription_episode_subscriber_id")
+                    .table(SubscriptionBangumi::Table)
+                    .to_owned(),
+            )
+            .await?;
+
         manager
             .drop_table(Table::drop().table(SubscriptionEpisode::Table).to_owned())
             .await?;
@@ -316,6 +349,16 @@ impl MigrationTrait for Migration {
             .drop_table(Table::drop().table(Episodes::Table).to_owned())
             .await?;

+        manager
+            .drop_index(
+                Index::drop()
+                    .if_exists()
+                    .name("index_subscription_bangumi_subscriber_id")
+                    .table(SubscriptionBangumi::Table)
+                    .to_owned(),
+            )
+            .await?;
+
         manager
             .drop_table(Table::drop().table(SubscriptionBangumi::Table).to_owned())
             .await?;

View File

@@ -2,13 +2,13 @@ use sea_orm_migration::{prelude::*, schema::*};

 use crate::{
     migrations::defs::{CustomSchemaManagerExt, Downloaders, GeneralIds, Subscribers},
-    models::{downloaders::DownloaderCategoryEnum, prelude::DownloaderCategory},
+    models::downloaders::{DownloaderCategory, DownloaderCategoryEnum},
 };

 #[derive(DeriveMigrationName)]
 pub struct Migration;

-#[async_trait]
+#[async_trait::async_trait]
 impl MigrationTrait for Migration {
     async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
         create_postgres_enum_for_active_enum!(

View File

@@ -34,7 +34,6 @@ impl MigrationTrait for Migration {
                         AuthTypeEnum,
                         AuthType::iden_values(),
                     ))
-                    .col(string_null(Auth::AvatarUrl))
                     .col(integer(Auth::SubscriberId))
                     .foreign_key(
                         ForeignKey::create()
@@ -66,6 +65,20 @@
             .create_postgres_auto_update_ts_trigger_for_col(Auth::Table, GeneralIds::UpdatedAt)
             .await?;

+        let seed_subscriber_id = manager
+            .get_connection()
+            .query_one(
+                manager.get_database_backend().build(
+                    Query::select()
+                        .column(Subscribers::Id)
+                        .from(Subscribers::Table)
+                        .limit(1),
+                ),
+            )
+            .await?
+            .ok_or_else(|| DbErr::RecordNotFound(String::from("seed subscriber not found")))?
+            .try_get_by_index::<i32>(0)?;
+
         manager
             .exec_stmt(
                 Query::insert()
@@ -74,7 +87,7 @@
                     .values_panic([
                         SEED_SUBSCRIBER.into(),
                         SimpleExpr::from(AuthType::Basic).as_enum(AuthTypeEnum),
-                        1.into(),
+                        seed_subscriber_id.into(),
                     ])
                     .to_owned(),
             )

View File

@@ -5,6 +5,7 @@ pub use sea_orm_migration::prelude::*;
 pub mod defs;
 pub mod m20220101_000001_init;
 pub mod m20240224_082543_add_downloads;
+pub mod m20240225_060853_subscriber_add_downloader;
 pub mod m20241231_000001_auth;

 pub struct Migrator;
@@ -15,6 +16,7 @@ impl MigratorTrait for Migrator {
         vec![
             Box::new(m20220101_000001_init::Migration),
             Box::new(m20240224_082543_add_downloads::Migration),
+            Box::new(m20240225_060853_subscriber_add_downloader::Migration),
             Box::new(m20241231_000001_auth::Migration),
         ]
     }

View File

@@ -1,7 +1,13 @@
 use async_trait::async_trait;
-use sea_orm::entity::prelude::*;
+use loco_rs::{
+    app::AppContext,
+    model::{ModelError, ModelResult},
+};
+use sea_orm::{Set, TransactionTrait, entity::prelude::*};
 use serde::{Deserialize, Serialize};

+use super::subscribers::{self, SEED_SUBSCRIBER};
+
 #[derive(
     Clone, Debug, PartialEq, Eq, EnumIter, DeriveActiveEnum, DeriveDisplay, Serialize, Deserialize,
 )]
@@ -17,14 +23,16 @@ pub enum AuthType {

 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, DeriveEntityModel)]
 #[sea_orm(table_name = "auth")]
 pub struct Model {
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,
+    #[sea_orm(unique)]
     pub pid: String,
     pub subscriber_id: i32,
     pub auth_type: AuthType,
-    pub avatar_url: Option<String>,
 }

 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -47,3 +55,52 @@ impl Related<super::subscribers::Entity> for Entity {

 #[async_trait]
 impl ActiveModelBehavior for ActiveModel {}
+
+impl Model {
+    pub async fn find_by_pid(ctx: &AppContext, pid: &str) -> ModelResult<Self> {
+        let db = &ctx.db;
+        let subscriber_auth = Entity::find()
+            .filter(Column::Pid.eq(pid))
+            .one(db)
+            .await?
+            .ok_or_else(|| ModelError::EntityNotFound)?;
+        Ok(subscriber_auth)
+    }
+
+    pub async fn create_from_oidc(ctx: &AppContext, sub: String) -> ModelResult<Self> {
+        let db = &ctx.db;
+
+        let txn = db.begin().await?;
+
+        let subscriber_id = if let Some(seed_subscriber_id) = Entity::find()
+            .filter(
+                Column::AuthType
+                    .eq(AuthType::Basic)
+                    .and(Column::Pid.eq(SEED_SUBSCRIBER)),
+            )
+            .one(&txn)
+            .await?
+            .map(|m| m.subscriber_id)
+        {
+            seed_subscriber_id
+        } else {
+            let new_subscriber = subscribers::ActiveModel {
+                ..Default::default()
+            };
+
+            let new_subscriber: subscribers::Model = new_subscriber.save(&txn).await?.try_into()?;
+            new_subscriber.id
+        };
+
+        let new_item = ActiveModel {
+            pid: Set(sub),
+            auth_type: Set(AuthType::Oidc),
+            subscriber_id: Set(subscriber_id),
+            ..Default::default()
+        };
+
+        let new_item: Model = new_item.save(&txn).await?.try_into()?;
+
+        Ok(new_item)
+    }
+}

View File

@@ -1,7 +1,7 @@
 use async_graphql::SimpleObject;
 use async_trait::async_trait;
 use loco_rs::app::AppContext;
-use sea_orm::{entity::prelude::*, sea_query::OnConflict, ActiveValue, FromJsonQueryResult};
+use sea_orm::{ActiveValue, FromJsonQueryResult, entity::prelude::*, sea_query::OnConflict};
 use serde::{Deserialize, Serialize};

 use super::subscription_bangumi;
@@ -9,7 +9,6 @@ use super::subscription_bangumi;
 #[derive(
     Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult, SimpleObject,
 )]
-#[graphql(name = "BangumiFilter")]
 pub struct BangumiFilter {
     pub name: Option<Vec<String>>,
     pub group: Option<Vec<String>>,
@@ -18,7 +17,6 @@ pub struct BangumiFilter {
 #[derive(
     Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult, SimpleObject,
 )]
-#[graphql(name = "BangumiExtra")]
 pub struct BangumiExtra {
     pub name_zh: Option<String>,
     pub s_name_zh: Option<String>,
@@ -30,14 +28,14 @@ pub struct BangumiExtra {

 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize, SimpleObject)]
 #[sea_orm(table_name = "bangumi")]
-#[graphql(name = "Bangumi")]
 pub struct Model {
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,
     pub mikan_bangumi_id: Option<String>,
-    #[graphql(default_with = "default_subscriber_id")]
     pub subscriber_id: i32,
     pub display_name: String,
     pub raw_name: String,

View File

@@ -22,9 +22,9 @@ pub enum DownloaderCategory {
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
 #[sea_orm(table_name = "downloaders")]
 pub struct Model {
-    #[sea_orm(column_type = "Timestamp")]
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
-    #[sea_orm(column_type = "Timestamp")]
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,

View File

@@ -38,7 +38,9 @@ pub enum DownloadMime {
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
 #[sea_orm(table_name = "downloads")]
 pub struct Model {
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,

View File

@@ -2,14 +2,14 @@ use std::sync::Arc;

 use async_trait::async_trait;
 use loco_rs::app::AppContext;
-use sea_orm::{entity::prelude::*, sea_query::OnConflict, ActiveValue, FromJsonQueryResult};
+use sea_orm::{ActiveValue, FromJsonQueryResult, entity::prelude::*, sea_query::OnConflict};
 use serde::{Deserialize, Serialize};

 use super::{bangumi, query::InsertManyReturningExt, subscription_episode};
 use crate::{
     app::AppContextExt,
     extract::{
-        mikan::{build_mikan_episode_homepage, MikanEpisodeMeta},
+        mikan::{MikanEpisodeMeta, build_mikan_episode_homepage},
         rawname::parse_episode_meta_from_raw_name,
     },
 };
@@ -27,7 +27,9 @@ pub struct EpisodeExtra {
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
 #[sea_orm(table_name = "episodes")]
 pub struct Model {
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,
@@ -135,6 +137,7 @@ pub struct MikanEpsiodeCreation {
 impl Model {
     pub async fn add_episodes(
         ctx: &AppContext,
+        subscriber_id: i32,
         subscription_id: i32,
         creations: impl IntoIterator<Item = MikanEpsiodeCreation>,
     ) -> color_eyre::eyre::Result<()> {
@@ -162,6 +165,7 @@ impl Model {
         let insert_subscription_episode_links = inserted_episodes.into_iter().map(|episode_id| {
             subscription_episode::ActiveModel::from_subscription_and_episode(
+                subscriber_id,
                 subscription_id,
                 episode_id,
             )

View File

@@ -4,7 +4,7 @@ use loco_rs::{
     app::AppContext,
     model::{ModelError, ModelResult},
 };
-use sea_orm::{entity::prelude::*, ActiveValue, FromJsonQueryResult, TransactionTrait};
+use sea_orm::{ActiveValue, FromJsonQueryResult, TransactionTrait, entity::prelude::*};
 use serde::{Deserialize, Serialize};

 pub const SEED_SUBSCRIBER: &str = "konobangu";
@@ -16,15 +16,15 @@ pub struct SubscriberBangumiConfig {
     pub leading_group_tag: Option<bool>,
 }

-#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize, SimpleObject)]
 #[sea_orm(table_name = "subscribers")]
 pub struct Model {
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,
-    #[sea_orm(unique)]
-    pub pid: String,
     pub display_name: String,
     pub bangumi_conf: Option<SubscriberBangumiConfig>,
 }
@@ -91,59 +91,22 @@ pub struct SubscriberIdParams {
 }

 #[async_trait]
-impl ActiveModelBehavior for ActiveModel {
-    async fn before_save<C>(self, _db: &C, insert: bool) -> Result<Self, DbErr>
-    where
-        C: ConnectionTrait,
-    {
-        if insert {
-            let mut this = self;
-            this.pid = ActiveValue::Set(Uuid::new_v4().to_string());
-            Ok(this)
-        } else {
-            Ok(self)
-        }
-    }
-}
+impl ActiveModelBehavior for ActiveModel {}

 impl Model {
-    pub async fn find_by_pid(ctx: &AppContext, pid: &str) -> ModelResult<Self> {
-        let db = &ctx.db;
-        let parse_uuid = Uuid::parse_str(pid).map_err(|e| ModelError::Any(e.into()))?;
-        let subscriber = Entity::find()
-            .filter(Column::Pid.eq(parse_uuid))
-            .one(db)
-            .await?;
-        subscriber.ok_or_else(|| ModelError::EntityNotFound)
+    pub async fn find_seed_subscriber_id(ctx: &AppContext) -> ModelResult<i32> {
+        let subscriber_auth = crate::models::auth::Model::find_by_pid(ctx, SEED_SUBSCRIBER).await?;
+        Ok(subscriber_auth.subscriber_id)
     }

     pub async fn find_by_id(ctx: &AppContext, id: i32) -> ModelResult<Self> {
         let db = &ctx.db;
-        let subscriber = Entity::find_by_id(id).one(db).await?;
-        subscriber.ok_or_else(|| ModelError::EntityNotFound)
-    }
-
-    pub async fn find_pid_by_id_with_cache(
-        ctx: &AppContext,
-        id: i32,
-    ) -> color_eyre::eyre::Result<String> {
-        let db = &ctx.db;
-        let cache = &ctx.cache;
-        let pid = cache
-            .get_or_insert(&format!("subscriber-id2pid::{}", id), async {
-                let subscriber = Entity::find_by_id(id)
-                    .one(db)
-                    .await?
-                    .ok_or_else(|| loco_rs::Error::string(&format!("No such pid for id {}", id)))?;
-                Ok(subscriber.pid)
-            })
-            .await?;
-        Ok(pid)
-    }
-
-    pub async fn find_root(ctx: &AppContext) -> ModelResult<Self> {
-        Self::find_by_pid(ctx, SEED_SUBSCRIBER).await
+        let subscriber = Entity::find_by_id(id)
+            .one(db)
+            .await?
+            .ok_or_else(|| ModelError::EntityNotFound)?;
+        Ok(subscriber)
     }

     pub async fn create_root(ctx: &AppContext) -> ModelResult<Self> {
@@ -152,7 +115,6 @@ impl Model {
         let user = ActiveModel {
             display_name: ActiveValue::set(SEED_SUBSCRIBER.to_string()),
-            pid: ActiveValue::set(SEED_SUBSCRIBER.to_string()),
             ..Default::default()
         }
         .insert(&txn)
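With the `pid` column removed from `subscribers`, callers that used to resolve the root subscriber by pid now go through the auth table. A short hedged sketch of the new lookup path; the helper function itself is illustrative:

    use loco_rs::{app::AppContext, model::ModelResult};
    use crate::models::subscribers;

    // Hypothetical helper: resolve the seed ("root") subscriber via auth.pid,
    // then load the subscriber row by its numeric id.
    async fn seed_subscriber(ctx: &AppContext) -> ModelResult<subscribers::Model> {
        let seed_id = subscribers::Model::find_seed_subscriber_id(ctx).await?;
        subscribers::Model::find_by_id(ctx, seed_id).await
    }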

View File

@@ -1,5 +1,5 @@
 use async_trait::async_trait;
-use sea_orm::{entity::prelude::*, ActiveValue};
+use sea_orm::{ActiveValue, entity::prelude::*};
 use serde::{Deserialize, Serialize};

 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
 pub struct Model {
     #[sea_orm(primary_key)]
     pub id: i32,
+    pub subscriber_id: i32,
     pub subscription_id: i32,
     pub bangumi_id: i32,
 }
@@ -55,8 +56,13 @@ pub enum RelatedEntity {
 impl ActiveModelBehavior for ActiveModel {}

 impl ActiveModel {
-    pub fn from_subscription_and_bangumi(subscription_id: i32, bangumi_id: i32) -> Self {
+    pub fn from_subscription_and_bangumi(
+        subscriber_id: i32,
+        subscription_id: i32,
+        bangumi_id: i32,
+    ) -> Self {
         Self {
+            subscriber_id: ActiveValue::Set(subscriber_id),
             subscription_id: ActiveValue::Set(subscription_id),
             bangumi_id: ActiveValue::Set(bangumi_id),
             ..Default::default()

View File

@@ -1,5 +1,5 @@
 use async_trait::async_trait;
-use sea_orm::{entity::prelude::*, ActiveValue};
+use sea_orm::{ActiveValue, entity::prelude::*};
 use serde::{Deserialize, Serialize};

 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
 pub struct Model {
     #[sea_orm(primary_key)]
     pub id: i32,
+    pub subscriber_id: i32,
     pub subscription_id: i32,
     pub episode_id: i32,
 }
@@ -55,8 +56,13 @@ pub enum RelatedEntity {
 impl ActiveModelBehavior for ActiveModel {}

 impl ActiveModel {
-    pub fn from_subscription_and_episode(subscription_id: i32, episode_id: i32) -> Self {
+    pub fn from_subscription_and_episode(
+        subscriber_id: i32,
+        subscription_id: i32,
+        episode_id: i32,
+    ) -> Self {
         Self {
+            subscriber_id: ActiveValue::Set(subscriber_id),
             subscription_id: ActiveValue::Set(subscription_id),
             episode_id: ActiveValue::Set(episode_id),
             ..Default::default()

View File

@@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc};
 use async_trait::async_trait;
 use itertools::Itertools;
 use loco_rs::app::AppContext;
-use sea_orm::{entity::prelude::*, ActiveValue};
+use sea_orm::{ActiveValue, entity::prelude::*};
 use serde::{Deserialize, Serialize};

 use super::{bangumi, episodes, query::filter_values_in};
@@ -15,8 +15,8 @@ use crate::{
             parse_mikan_bangumi_meta_from_mikan_homepage,
             parse_mikan_episode_meta_from_mikan_homepage, parse_mikan_rss_channel_from_rss_link,
             web_parser::{
-                parse_mikan_bangumi_poster_from_origin_poster_src_with_cache,
                 MikanBangumiPosterMeta,
+                parse_mikan_bangumi_poster_from_origin_poster_src_with_cache,
             },
         },
         rawname::extract_season_from_title_body,
@@ -43,9 +43,9 @@ pub enum SubscriptionCategory {
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
 #[sea_orm(table_name = "subscriptions")]
 pub struct Model {
-    #[sea_orm(column_type = "Timestamp")]
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub created_at: DateTime,
-    #[sea_orm(column_type = "Timestamp")]
+    #[sea_orm(default_expr = "Expr::current_timestamp()")]
     pub updated_at: DateTime,
     #[sea_orm(primary_key)]
     pub id: i32,
@@ -325,6 +325,7 @@ impl Model {
         );

         episodes::Model::add_episodes(
             ctx,
+            self.subscriber_id,
             self.id,
             new_ep_metas.into_iter().map(|item| MikanEpsiodeCreation {
                 episode: item,

View File

@@ -3,17 +3,11 @@ use serde::{Deserialize, Serialize};
 use crate::models::subscribers;

 #[derive(Debug, Deserialize, Serialize)]
-pub struct CurrentResponse {
-    pub pid: String,
-    pub display_name: String,
-}
+pub struct CurrentResponse {}

 impl CurrentResponse {
     #[must_use]
-    pub fn new(user: &subscribers::Model) -> Self {
-        Self {
-            pid: user.pid.to_string(),
-            display_name: user.display_name.to_string(),
-        }
+    pub fn new(_user: &subscribers::Model) -> Self {
+        Self {}
     }
 }

View File

@@ -1,16 +1,16 @@
-use insta::assert_debug_snapshot;
-use loco_rs::testing;
-use recorder::{app::App, models::subscribers::Model};
+// use insta::assert_debug_snapshot;
+// use loco_rs::testing;
+// use recorder::{app::App, models::subscribers::Model};
 use serial_test::serial;

-macro_rules! configure_insta {
-    ($($expr:expr),*) => {
-        let mut settings = insta::Settings::clone_current();
-        settings.set_prepend_module_to_snapshot(false);
-        settings.set_snapshot_suffix("users");
-        let _guard = settings.bind_to_scope();
-    };
-}
+// macro_rules! configure_insta {
+//     ($($expr:expr),*) => {
+//         let mut settings = insta::Settings::clone_current();
+//         settings.set_prepend_module_to_snapshot(false);
+//         settings.set_snapshot_suffix("users");
+//         let _guard = settings.bind_to_scope();
+//     };
+// }

 #[tokio::test]
 #[serial]