Skip to content

Commit

Permalink
support cors array string config
Browse files Browse the repository at this point in the history
  • Loading branch information
timzaak committed Sep 4, 2024
1 parent 40ceccc commit 9258253
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 27 deletions.
8 changes: 4 additions & 4 deletions config.release.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
## directory to store static web files. if you use docker, please mount a persistence volume for it.
file_dir = "/data"

## enable cors, default is none, when set '*', then all cors is ok.
## enable cors, default is none, if cors is [], then all cors is ok.
## Access-Control-Allow-Origin: $ORIGIN
## Access-Control-Allow-Methods: OPTION,GET,HEAD
## Access-Control-Max-Age: 3600
## If you put the server behind HTTPS proxy, please enable it, or domains.cors = ['?']
## Attension: domains.cors would overwrite the cors config, rather than merge this.
cors = ['*']
## If you put the server behind HTTPS proxy, please enable it, or domains.cors = ['http://www.example.com:8080']
## Attension: domains.cors config would overwrite the cors config, rather than merge this.
cors = []
## http bind, if set port <= 0 or remove http, will disable http server(need set https config)
[http]
port = 80
Expand Down
8 changes: 4 additions & 4 deletions docs/guide/spa-server-configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ The config default path is `./config.toml`, you can change it by environment `SP
## directory to store static web files. if you use docker, please mount a persistence volume for it.
file_dir = "/data"

## enable cors, default is none, when set '*', then all cors is ok.
## enable cors, default is none, if cors is [], then all cors is ok.
## Access-Control-Allow-Origin: $ORIGIN
## Access-Control-Allow-Methods: OPTION,GET,HEAD
## Access-Control-Max-Age: 3600
## If you put the server behind HTTPS proxy, please enable it, or domains.cors = ['?']
## Attension: domains.cors would overwrite the cors config, rather than merge this.
cors = ['*']
## If you put the server behind HTTPS proxy, please enable it, or domains.cors = ['http://www.example.com:8080']
## Attension: domains.cors config would overwrite the cors config, rather than merge this.
cors = []
## http bind, if set port <= 0 or remove http, will disable http server(need set https config)
[http]
port = 80
Expand Down
31 changes: 28 additions & 3 deletions server/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use anyhow::{bail, Context};
use duration_str::deserialize_duration;
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use small_acme::LetsEncrypt;
use std::time::Duration;
use std::{env, fs};
use std::collections::HashSet;
use headers::{HeaderValue, Origin};
use tracing::warn;

const CONFIG_PATH: &str = "/config/config.toml";
Expand All @@ -12,7 +14,7 @@ const CONFIG_PATH: &str = "/config/config.toml";
pub struct Config {
pub file_dir: String,
#[serde(default)]
pub cors: bool,
pub cors: Option<HashSet<OriginWrapper>>,
pub admin_config: Option<AdminConfig>,
pub http: Option<HttpConfig>,
pub https: Option<HttpsConfig>,
Expand Down Expand Up @@ -94,7 +96,7 @@ fn default_max_upload_size() -> u64 {
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct DomainConfig {
pub domain: String,
pub cors: Option<bool>,
pub cors: Option<HashSet<OriginWrapper>>,
pub cache: Option<DomainCacheConfig>,
pub https: Option<DomainHttpsConfig>,
pub alias: Option<Vec<String>>,
Expand Down Expand Up @@ -226,6 +228,29 @@ pub fn get_host_path_from_domain(domain: &str) -> (&str, &str) {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OriginWrapper(HeaderValue);

pub(crate) fn extract_origin(data:&Option<HashSet<OriginWrapper>>) -> Option<HashSet<HeaderValue>> {
data.as_ref().map(|set| set.iter().map(|o| o.0.clone()).collect())
}

impl <'de> Deserialize<'de> for OriginWrapper {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
let data = String::deserialize(deserializer)?;
let mut parts = data.splitn(2, "://");
let scheme = parts.next().expect("missing scheme");
let rest = parts.next().expect("missing scheme");
let origin = Origin::try_from_parts(scheme, rest, None).expect("invalid Origin");

Ok(OriginWrapper(origin.to_string().parse()
.expect("Origin is always a valid HeaderValue")))
}
}

#[cfg(test)]
mod test {
use std::env;
Expand Down
15 changes: 12 additions & 3 deletions server/src/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,23 @@ pub fn cors_resp(mut res: Response<Body>, origin: HeaderValue) -> Response<Body>
res
}

fn is_origin_allowed(origins: &Option<HashSet<HeaderValue>>, origin: &HeaderValue) -> bool {
if let Some(ref allowed) = origins {
allowed.is_empty()||allowed.contains(origin)
} else {
false
}
}
// preflight response
pub fn resp_cors_request(
method: &Method,
headers: &HeaderMap,
allow_cors: bool,
origins: &Option<HashSet<HeaderValue>>,
) -> Either<Validated, Response<Body>> {
match (headers.get(header::ORIGIN), method) {
(Some(origin), &Method::OPTIONS) => {
if !allow_cors {

if !is_origin_allowed(origins, origin) {
return Either::Right(resp(StatusCode::FORBIDDEN, "origin not allowed"));
}
if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
Expand All @@ -59,7 +68,7 @@ pub fn resp_cors_request(
Either::Right(res)
}
(Some(origin), _) => {
if !allow_cors {
if !is_origin_allowed(origins, origin) {
Either::Right(resp(StatusCode::FORBIDDEN, "origin not allowed"))
} else {
Either::Left(Validated::Simple(origin.clone()))
Expand Down
10 changes: 5 additions & 5 deletions server/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use crate::acme::{get_challenge_path, ChallengePath, ACME_CHALLENGE};
use crate::cors::{cors_resp, resp_cors_request, Validated};
use crate::DomainStorage;
use futures_util::future::Either;
use headers::HeaderMapExt;
use headers::{HeaderMapExt, HeaderValue};
use hyper::header::LOCATION;
use hyper::http::uri::Authority;
use hyper::{Body, Request, Response, StatusCode};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -24,7 +24,7 @@ pub struct ServiceConfig {
}

pub struct DomainServiceConfig {
pub cors: bool,
pub cors: Option<HashSet<HeaderValue>>,
pub redirect_https: Option<u16>,
pub enable_acme: bool,
}
Expand Down Expand Up @@ -130,7 +130,7 @@ pub async fn create_http_service(

let service_config = service_config.get_domain_service_config(host);
// cors
let origin_opt = match resp_cors_request(req.method(), req.headers(), service_config.cors) {
let origin_opt = match resp_cors_request(req.method(), req.headers(), &service_config.cors) {
Either::Left(x) => Some(x),
Either::Right(v) => return Ok(v),
};
Expand Down Expand Up @@ -194,7 +194,7 @@ pub async fn create_https_service(

let service_config = service_config.get_domain_service_config(host);
// cors
let origin_opt = match resp_cors_request(req.method(), req.headers(), service_config.cors) {
let origin_opt = match resp_cors_request(req.method(), req.headers(), &service_config.cors) {
Either::Left(x) => Some(x),
Either::Right(v) => return Ok(v),
};
Expand Down
6 changes: 3 additions & 3 deletions server/src/web_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use futures_util::future::Either;
use tokio::net::TcpListener as TKTcpListener;
use tokio::sync::oneshot::Receiver;

use crate::config::{Config, HttpConfig, HttpsConfig};
use crate::config::{extract_origin, Config, HttpConfig, HttpsConfig};
use crate::domain_storage::DomainStorage;
use crate::service::{create_http_service, create_https_service, DomainServiceConfig, ServiceConfig};
use crate::tls::TlsAcceptor;
Expand Down Expand Up @@ -93,7 +93,7 @@ impl Server {
};

let default = DomainServiceConfig {
cors: conf.cors,
cors: extract_origin(&conf.cors),
redirect_https: default_http_redirect_to_https,
enable_acme: conf.https.as_ref().and_then(|x| x.acme.as_ref()).is_some(),
};
Expand All @@ -116,7 +116,7 @@ impl Server {
Some(false) => None
};
let domain_service_config: DomainServiceConfig = DomainServiceConfig {
cors: domain.cors.unwrap_or(default.cors),
cors: extract_origin(&domain.cors).or_else(||default.cors.clone()),
redirect_https,
enable_acme: domain
.https
Expand Down
2 changes: 1 addition & 1 deletion tests/data/server_config.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
file_dir = "./data/web"
cors = true
cors = []
[http]
port = 8080
addr = "0.0.0.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/data/server_config_acme.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
file_dir = "./data/web"
cors = true
cors = []

# http bind, if set port <= 0, will disable http server(need set https config)
[http]
Expand Down
2 changes: 1 addition & 1 deletion tests/data/server_config_acme_alias.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
file_dir = "./data/web"
cors = true
cors = []

# http bind, if set port <= 0, will disable http server(need set https config)
[http]
Expand Down
2 changes: 1 addition & 1 deletion tests/data/server_config_alias.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cors = true
cors = []
file_dir = "./data/web"


Expand Down
2 changes: 1 addition & 1 deletion tests/data/server_config_https.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
file_dir = "./data/web"
cors = true
cors = []

# http bind, if set port <= 0, will disable http server(need set https config)
[http]
Expand Down

0 comments on commit 9258253

Please sign in to comment.