Skip to content

Commit

Permalink
Merge pull request #206 from Enraged-Dun-Cookie-Development-Team/feat…
Browse files Browse the repository at this point in the history
…-修改跨域中间件

[feat]条件跨域请求约束添加
  • Loading branch information
Goodjooy authored Sep 3, 2024
2 parents 892fd3c + 03c4d44 commit bcef6fb
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 25 deletions.
104 changes: 92 additions & 12 deletions src/bootstrap/middleware/cors.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,108 @@
use std::{collections::HashSet, sync::Arc, task::Poll};

use axum_starter::{prepare, PrepareMiddlewareEffect};
use http::{HeaderValue, Method};
use tower_http::cors::CorsLayer;
use tower::{Layer, Service};
use tower_http::cors::{Any, Cors, CorsLayer};

pub trait CorsConfigTrait {
fn allow_origins(&self) -> Vec<HeaderValue>;

fn allow_methods(&self) -> Vec<Method>;

fn bypass_paths(&self) -> Arc<HashSet<String>>;
}

#[prepare(PrepareCors)]
pub fn prepare_cors<T: CorsConfigTrait>(cfg: &T) -> CorsMiddleware {
CorsMiddleware(
CorsLayer::new()
.allow_origin(cfg.allow_origins())
.allow_methods(cfg.allow_methods()),
)
#[derive(Clone)]
pub struct ConditionCors<S> {
bypass_cors: Cors<S>,
default_cors: Cors<S>,
bypass_paths: Arc<HashSet<String>>,
}

pub struct CorsMiddleware(CorsLayer);
impl<S, Req, Resp> Service<http::Request<Req>> for ConditionCors<S>
where
S: Service<http::Request<Req>, Response = http::Response<Resp>>,
Req: Default,
Resp: Default,
{
type Error = <Cors<S> as Service<http::Request<Req>>>::Error;
type Future = <Cors<S> as Service<http::Request<Req>>>::Future;
type Response = <Cors<S> as Service<http::Request<Req>>>::Response;

fn poll_ready(
&mut self, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
match (
self.bypass_cors.poll_ready(cx),
self.default_cors.poll_ready(cx),
) {
(Poll::Ready(bypass), Poll::Ready(default)) => {
Poll::Ready(bypass.and(default))
}
(Poll::Ready(_), Poll::Pending) => Poll::Pending,
(Poll::Pending, Poll::Ready(_)) => Poll::Pending,
(Poll::Pending, Poll::Pending) => Poll::Pending,
}
}

fn call(&mut self, req: http::Request<Req>) -> Self::Future {
let uri = req.uri().path();
if self.bypass_paths.contains(uri) {
self.bypass_cors.call(req)
}
else {
self.default_cors.call(req)
}
}
}

impl<S> PrepareMiddlewareEffect<S> for CorsMiddleware {
type Middleware = CorsLayer;
#[derive(Clone)]
pub struct ConditionCorsLayer {
bypass_cors: CorsLayer,
default_cors: CorsLayer,
bypass_paths: Arc<HashSet<String>>,
}

impl ConditionCorsLayer {
fn from_config(config: &impl CorsConfigTrait) -> Self {
Self {
bypass_cors: CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any),
default_cors: CorsLayer::new()
.allow_origin(config.allow_origins())
.allow_methods(config.allow_methods()),
bypass_paths: config.bypass_paths(),
}
}
}

impl<S: Clone> Layer<S> for ConditionCorsLayer {
type Service = ConditionCors<S>;

fn layer(&self, inner: S) -> Self::Service {
ConditionCors {
bypass_cors: self.bypass_cors.layer(inner.clone()),
default_cors: self.default_cors.layer(inner),
bypass_paths: Arc::clone(&self.bypass_paths),
}
}
}

impl<S: Clone> PrepareMiddlewareEffect<S> for ConditionCorsLayer {
type Middleware = ConditionCorsLayer;

fn take(self, _: &mut axum_starter::StateCollector) -> Self::Middleware {
self.0
self
}
}

pub type ConditionCorsEffect = ConditionCorsLayer;

#[prepare(ConditionCorsPrepare)]
pub fn prepare_condition_cors<T: CorsConfigTrait>(
config: &T,
) -> ConditionCorsEffect {
ConditionCorsEffect::from_config(config)
}
8 changes: 8 additions & 0 deletions src/configs/cors_config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{collections::HashSet, sync::Arc};

use http::{HeaderValue, Method};
use serde::{Deserialize, Deserializer};

Expand All @@ -9,12 +11,18 @@ pub struct CorsConfigImpl {
allow_origins: Vec<HeaderValue>,
#[serde(alias = "methods", deserialize_with = "de_methods")]
allow_methods: Vec<Method>,
#[serde(alias = "paths")]
bypass_path: Arc<HashSet<String>>,
}

impl CorsConfigTrait for CorsConfigImpl {
fn allow_origins(&self) -> Vec<HeaderValue> { self.allow_origins.clone() }

fn allow_methods(&self) -> Vec<Method> { self.allow_methods.clone() }

fn bypass_paths(&self) -> Arc<HashSet<String>> {
Arc::clone(&self.bypass_path)
}
}

fn de_origins<'de, D: Deserializer<'de>>(
Expand Down
4 changes: 2 additions & 2 deletions src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use tracing::{error, instrument, warn};

#[macro_export]
/// 1. 辅助构造枚举形式的Error,
/// 并提供 [Form](std::convert::Form)转换实现,
/// 和 [StatusErr](status_err::StatusErr)实现
/// 并提供 [Form](std::convert::Form)转换实现,
/// 和 [StatusErr](status_err::StatusErr)实现
/// ```rust
/// error_generate!(
/// // |------- 构造的枚举型异常的类型名称
Expand Down
6 changes: 4 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use bootstrap::{
service_init::{graceful_shutdown, RouteV1, RouterFallback},
},
middleware::{
cors::PrepareCors, panic_report::PrepareCatchPanic,
cors::ConditionCorsPrepare, panic_report::PrepareCatchPanic,
tracing_request::PrepareRequestTracker,
},
};
Expand Down Expand Up @@ -95,7 +95,9 @@ async fn main_task() {
// router
.prepare_route(RouteV1)
.prepare_route(RouterFallback)
.prepare_middleware::<Route, _>(PrepareCors::<_, CorsConfigImpl>)
.prepare_middleware::<Route, _>(
ConditionCorsPrepare::<_, CorsConfigImpl>,
)
.prepare_middleware::<Route, _>(
PrepareCatchPanic::<_, QqChannelConfig>,
)
Expand Down
18 changes: 9 additions & 9 deletions src/serves/frontend/ceobe/operation/announcement/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ impl From<announcement::Model> for AnnouncementItem {
}
}

#[cfg(test)]
mod test {
#[test]
fn test_url() {
let url = url::Url::parse("icon");
println!("{url:?}")
}
}

/// 用于请求头缓存信息生成
pub struct AnnouncementItems(pub(super) Vec<AnnouncementItem>);
impl AnnouncementItems {
Expand All @@ -67,3 +58,12 @@ impl ModifyState for AnnouncementItems {
Cow::Borrowed(&self.0)
}
}

#[cfg(test)]
mod test {
#[test]
fn test_url() {
let url = url::Url::parse("icon");
println!("{url:?}")
}
}

0 comments on commit bcef6fb

Please sign in to comment.