diff --git a/src/bootstrap/middleware/cors.rs b/src/bootstrap/middleware/cors.rs index ccbb9face..488d11607 100644 --- a/src/bootstrap/middleware/cors.rs +++ b/src/bootstrap/middleware/cors.rs @@ -1,28 +1,108 @@ +use std::{collections::HashSet, sync::Arc, task::Poll}; + use axum_starter::{prepare, PrepareMiddlewareEffect}; use http::{HeaderValue, Method}; -use tower_http::cors::CorsLayer; +use tower::{Layer, Service}; +use tower_http::cors::{Any, Cors, CorsLayer}; pub trait CorsConfigTrait { fn allow_origins(&self) -> Vec; fn allow_methods(&self) -> Vec; + + fn bypass_paths(&self) -> Arc>; } -#[prepare(PrepareCors)] -pub fn prepare_cors(cfg: &T) -> CorsMiddleware { - CorsMiddleware( - CorsLayer::new() - .allow_origin(cfg.allow_origins()) - .allow_methods(cfg.allow_methods()), - ) +#[derive(Clone)] +pub struct ConditionCors { + bypass_cors: Cors, + default_cors: Cors, + bypass_paths: Arc>, } -pub struct CorsMiddleware(CorsLayer); +impl Service> for ConditionCors +where + S: Service, Response = http::Response>, + Req: Default, + Resp: Default, +{ + type Error = as Service>>::Error; + type Future = as Service>>::Future; + type Response = as Service>>::Response; + + fn poll_ready( + &mut self, cx: &mut std::task::Context<'_>, + ) -> Poll> { + match ( + self.bypass_cors.poll_ready(cx), + self.default_cors.poll_ready(cx), + ) { + (Poll::Ready(bypass), Poll::Ready(default)) => { + Poll::Ready(bypass.and(default)) + } + (Poll::Ready(_), Poll::Pending) => Poll::Pending, + (Poll::Pending, Poll::Ready(_)) => Poll::Pending, + (Poll::Pending, Poll::Pending) => Poll::Pending, + } + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let uri = req.uri().path(); + if self.bypass_paths.contains(uri) { + self.bypass_cors.call(req) + } + else { + self.default_cors.call(req) + } + } +} -impl PrepareMiddlewareEffect for CorsMiddleware { - type Middleware = CorsLayer; +#[derive(Clone)] +pub struct ConditionCorsLayer { + bypass_cors: CorsLayer, + default_cors: CorsLayer, + bypass_paths: Arc>, +} + +impl ConditionCorsLayer { + fn from_config(config: &impl CorsConfigTrait) -> Self { + Self { + bypass_cors: CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any), + default_cors: CorsLayer::new() + .allow_origin(config.allow_origins()) + .allow_methods(config.allow_methods()), + bypass_paths: config.bypass_paths(), + } + } +} + +impl Layer for ConditionCorsLayer { + type Service = ConditionCors; + + fn layer(&self, inner: S) -> Self::Service { + ConditionCors { + bypass_cors: self.bypass_cors.layer(inner.clone()), + default_cors: self.default_cors.layer(inner), + bypass_paths: Arc::clone(&self.bypass_paths), + } + } +} + +impl PrepareMiddlewareEffect for ConditionCorsLayer { + type Middleware = ConditionCorsLayer; fn take(self, _: &mut axum_starter::StateCollector) -> Self::Middleware { - self.0 + self } } + +pub type ConditionCorsEffect = ConditionCorsLayer; + +#[prepare(ConditionCorsPrepare)] +pub fn prepare_condition_cors( + config: &T, +) -> ConditionCorsEffect { + ConditionCorsEffect::from_config(config) +} diff --git a/src/configs/cors_config.rs b/src/configs/cors_config.rs index da4735c64..b96810645 100644 --- a/src/configs/cors_config.rs +++ b/src/configs/cors_config.rs @@ -1,3 +1,5 @@ +use std::{collections::HashSet, sync::Arc}; + use http::{HeaderValue, Method}; use serde::{Deserialize, Deserializer}; @@ -9,12 +11,18 @@ pub struct CorsConfigImpl { allow_origins: Vec, #[serde(alias = "methods", deserialize_with = "de_methods")] allow_methods: Vec, + #[serde(alias = "paths")] + bypass_path: Arc>, } impl CorsConfigTrait for CorsConfigImpl { fn allow_origins(&self) -> Vec { self.allow_origins.clone() } fn allow_methods(&self) -> Vec { self.allow_methods.clone() } + + fn bypass_paths(&self) -> Arc> { + Arc::clone(&self.bypass_path) + } } fn de_origins<'de, D: Deserializer<'de>>( diff --git a/src/error/mod.rs b/src/error/mod.rs index 6ceb46d64..81dc31583 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -10,8 +10,8 @@ use tracing::{error, instrument, warn}; #[macro_export] /// 1. 辅助构造枚举形式的Error, -/// 并提供 [Form](std::convert::Form)转换实现, -/// 和 [StatusErr](status_err::StatusErr)实现 +/// 并提供 [Form](std::convert::Form)转换实现, +/// 和 [StatusErr](status_err::StatusErr)实现 /// ```rust /// error_generate!( /// // |------- 构造的枚举型异常的类型名称 diff --git a/src/main.rs b/src/main.rs index fbbceac79..3f43754cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ use bootstrap::{ service_init::{graceful_shutdown, RouteV1, RouterFallback}, }, middleware::{ - cors::PrepareCors, panic_report::PrepareCatchPanic, + cors::ConditionCorsPrepare, panic_report::PrepareCatchPanic, tracing_request::PrepareRequestTracker, }, }; @@ -95,7 +95,9 @@ async fn main_task() { // router .prepare_route(RouteV1) .prepare_route(RouterFallback) - .prepare_middleware::(PrepareCors::<_, CorsConfigImpl>) + .prepare_middleware::( + ConditionCorsPrepare::<_, CorsConfigImpl>, + ) .prepare_middleware::( PrepareCatchPanic::<_, QqChannelConfig>, ) diff --git a/src/serves/frontend/ceobe/operation/announcement/view.rs b/src/serves/frontend/ceobe/operation/announcement/view.rs index f51ff703b..9644b3bcb 100644 --- a/src/serves/frontend/ceobe/operation/announcement/view.rs +++ b/src/serves/frontend/ceobe/operation/announcement/view.rs @@ -42,15 +42,6 @@ impl From for AnnouncementItem { } } -#[cfg(test)] -mod test { - #[test] - fn test_url() { - let url = url::Url::parse("icon"); - println!("{url:?}") - } -} - /// 用于请求头缓存信息生成 pub struct AnnouncementItems(pub(super) Vec); impl AnnouncementItems { @@ -67,3 +58,12 @@ impl ModifyState for AnnouncementItems { Cow::Borrowed(&self.0) } } + +#[cfg(test)] +mod test { + #[test] + fn test_url() { + let url = url::Url::parse("icon"); + println!("{url:?}") + } +}